From 50c8a7122a38359fcadbe80facbeba29ae2bb5c9 Mon Sep 17 00:00:00 2001 From: xiyuren <761346811@qq.com> Date: Thu, 22 Aug 2024 10:06:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8B=89=E5=BC=BA=E8=83=BD=E7=94=A8=E7=9A=84?= =?UTF-8?q?=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 + DualConvMixer.py | 131 ++++++++ DualTrain.py | 163 ++++++++++ loss/__init__.py | 0 loss/losses.py | 66 +++++ loss/perceptual_similarity/__init__.py | 0 loss/perceptual_similarity/base_model.py | 53 ++++ loss/perceptual_similarity/dist_model.py | 280 ++++++++++++++++++ loss/perceptual_similarity/networks_basic.py | 184 ++++++++++++ loss/perceptual_similarity/perceptual_loss.py | 163 ++++++++++ .../pretrained_networks.py | 180 +++++++++++ .../weights/v0.1/alex.pth | Bin 0 -> 6009 bytes .../weights/v0.1/squeeze.pth | Bin 0 -> 10811 bytes .../weights/v0.1/vgg.pth | Bin 0 -> 7289 bytes network/AutoEncoder.py | 48 +++ network/__init__.py | 1 + network/channel.py | 59 ++++ network/discriminator.py | 95 ++++++ network/encoder.py | 127 ++++++++ network/generator.py | 195 ++++++++++++ network/hyper.py | 130 ++++++++ network/instance.py | 15 + 22 files changed, 1895 insertions(+) create mode 100644 DualConvMixer.py create mode 100644 DualTrain.py create mode 100644 loss/__init__.py create mode 100644 loss/losses.py create mode 100644 loss/perceptual_similarity/__init__.py create mode 100644 loss/perceptual_similarity/base_model.py create mode 100644 loss/perceptual_similarity/dist_model.py create mode 100644 loss/perceptual_similarity/networks_basic.py create mode 100644 loss/perceptual_similarity/perceptual_loss.py create mode 100644 loss/perceptual_similarity/pretrained_networks.py create mode 100644 loss/perceptual_similarity/weights/v0.1/alex.pth create mode 100644 loss/perceptual_similarity/weights/v0.1/squeeze.pth create mode 100644 loss/perceptual_similarity/weights/v0.1/vgg.pth create mode 100644 network/AutoEncoder.py create mode 100644 network/__init__.py create mode 100644 network/channel.py create mode 100644 network/discriminator.py create mode 100644 network/encoder.py create mode 100644 network/generator.py create mode 100644 network/hyper.py create mode 100644 network/instance.py diff --git a/.gitignore b/.gitignore index ed07724..da61b35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,8 @@ __pycache__ *.json *.bin *.pth +*.txt +STL10Data +!loss/perceptual_similarity/weights/v0.1/vgg.pth +!loss/perceptual_similarity/weights/v0.1/squeeze.pth +!loss/perceptual_similarity/weights/v0.1/alex.pth \ No newline at end of file diff --git a/DualConvMixer.py b/DualConvMixer.py new file mode 100644 index 0000000..e40fcd5 --- /dev/null +++ b/DualConvMixer.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# class Residual(nn.Module): +# def __init__(self, fn): +# super().__init__() +# self.fn = fn + +# def forward(self, x): +# return self.fn(x) + x + + +# class ConvBlock(nn.Module): +# def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): +# super(ConvBlock, self).__init__() +# self.conv = nn.Sequential( +# Residual(nn.Sequential( +# nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding), +# nn.GELU(), +# nn.BatchNorm2d(in_channels) +# )), +# nn.Conv2d(in_channels, out_channels, kernel_size=1), +# nn.GELU(), +# nn.BatchNorm2d(out_channels) +# ) + +# def forward(self, x): +# return self.conv(x) + +class ConvBlock(nn.Module): + def __init__(self, in_channels, 
out_channels, kernel_size=3, stride=1, padding=1): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + +class UpBlock(nn.Module): + def __init__(self, in_channels, out_channels,kernel_size=3, stride=1, padding=1): + super(UpBlock, self).__init__() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + # self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False) + self.conv = ConvBlock(in_channels, out_channels,kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x): + x = self.upsample(x) + return self.conv(x) + +# class DeconvBlock(nn.Module): +# def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding=0): +# super(DeconvBlock, self).__init__() +# self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding) +# self.bn = nn.BatchNorm2d(out_channels) +# self.relu = nn.ReLU(inplace=True) + +# def forward(self, x): +# return self.relu(self.bn(self.deconv(x))) + +class HourglassModel(nn.Module): + def __init__(self, input_channels=3, latent_dim=128): + super(HourglassModel, self).__init__() + + # Encoder (正卷积部分) + self.encoder = nn.Sequential( + ConvBlock(input_channels, 64), + ConvBlock(64, 64), + nn.MaxPool2d(2), + ConvBlock(64, 128), + ConvBlock(128, 128), + nn.MaxPool2d(2), + ConvBlock(128, 256), + ConvBlock(256, 256), + nn.MaxPool2d(2), + ConvBlock(256, 512), + ConvBlock(512, 512), + nn.MaxPool2d(2), + ConvBlock(512, latent_dim) + ) + + # Decoder + self.decoder = nn.Sequential( + UpBlock(latent_dim, 512), + UpBlock(512, 256), + UpBlock(256, 128), + UpBlock(128, 64), + nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1), + nn.Tanh() + ) + + # # Decoder (反卷积部分) + # self.decoder = nn.Sequential( + # DeconvBlock(latent_dim, 512, kernel_size=3, stride=2, padding=1, output_padding=1), + # DeconvBlock(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), + # DeconvBlock(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), + # DeconvBlock(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), + # nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1), + # nn.Tanh() + # ) + + def forward(self, x): + latent = self.encoder(x) + output = self.decoder(latent) + return output + + def encode(self, x): + return self.encoder(x) + + def decode(self, latent): + return self.decoder(latent) + +# 测试模型 +if __name__ == "__main__": + # 创建一个示例输入 + batch_size = 1 + channels = 3 + height, width = 32, 32 + x = torch.randn(batch_size, channels, height, width) + + # 初始化模型 + model = HourglassModel() + + # 前向传播 + output = model(x) + + print(model.encode(x).shape) + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + # print(f"Model architecture:\n{model}") \ No newline at end of file diff --git a/DualTrain.py b/DualTrain.py new file mode 100644 index 0000000..512b59d --- /dev/null +++ b/DualTrain.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +import os +from tqdm import tqdm +import matplotlib.pyplot as plt +from DualConvMixer import HourglassModel +from torchvision.models import vgg16 +from torchvision.transforms.functional import to_pil_image 
+from torch.optim.lr_scheduler import OneCycleLR +from network.AutoEncoder import AutoEncoder +import numpy as np +import torch.nn.functional as F +from torchvision.models import vgg16 +from loss.perceptual_similarity.perceptual_loss import PerceptualLoss +import kornia +import colour +import numpy as np + +# 设置设备 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# 创建保存结果的文件夹 +os.makedirs("results", exist_ok=True) +os.makedirs("checkpoints", exist_ok=True) + +# 数据预处理 + +transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: kornia.color.rgb_to_lab(x.unsqueeze(0)).squeeze(0)) +]) + +batch_size = 128 +# 加载STL10数据集 +train_dataset = datasets.STL10(root='./STL10Data', split='train+unlabeled', download=True, transform=transform) +train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8) + +# 加载CIFAR-10数据集 +# train_dataset = datasets.CIFAR10(root='./CIFAR10RawData', train=True, download=True, transform=transform) +# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2) + +# 初始化模型 +# model = HourglassModel().to(device) +model = AutoEncoder(image_dims=(3, 256, 256), batch_size=batch_size, + C=16,activation='leaky_relu').to(device) + +perceptual_loss = PerceptualLoss(model='net-lin', net='alex', + colorspace='Lab', use_gpu=torch.cuda.is_available()).to(device) +# 优化器 +optimizer = optim.AdamW(model.parameters(), lr=0.001) + +import torch.nn.functional as F +from kornia.losses import ssim_loss #psnr_loss + +def lab_loss(output, target): + # Calculate PSNR loss + # psnr = psnr_loss(output, target, max_val=255.0) # Assuming 8-bit color depth + # psnr_loss_value = 1.0 - psnr / 100.0 # Convert PSNR to a loss (0 to 1 range) + + # Separate channel losses + l_loss = F.mse_loss(output[:, 0], target[:, 0]) / (100.0 ** 2) # Normalize by square of range + a_loss = F.mse_loss(output[:, 1], target[:, 1]) / (255.0 ** 2) # Normalize by square of range + b_loss = F.mse_loss(output[:, 2], target[:, 2]) / (255.0 ** 2) # Normalize by square of range + + # SSIM loss (applied on L channel) + ssim = ssim_loss(output[:, 0].unsqueeze(1), target[:, 0].unsqueeze(1), window_size=11) + + # Total Variation loss (you might want to normalize this too) + tv_loss = (torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) + + torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))) / 255.0 + + high_freq_output = output - F.avg_pool2d(output, kernel_size=3, stride=1, padding=1) + high_freq_target = target - F.avg_pool2d(target, kernel_size=3, stride=1, padding=1) + high_freq_loss = F.mse_loss(high_freq_output, high_freq_target) + + # Combine losses (adjust weights as needed) + total_loss = l_loss + 0.5 * (a_loss + b_loss) + 0.1 * ssim + 0.01 * tv_loss+ 0.1 * high_freq_loss + return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss # Return individual losses for monitoring + # total_loss = 0.5 * psnr_loss_value + 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss + # return total_loss, psnr_loss_value, l_loss, a_loss, b_loss, ssim, tv_loss + + +# 训练函数 +def train(epoch): + model.train() + total_loss = 0 + for batch_idx, (data, _) in enumerate(tqdm(train_loader)): + data = data.to(device) + optimizer.zero_grad() + output = model(data) + + # Calculate perceptual loss and take the mean to get a scalar + loss_perceptual = perceptual_loss(output, data).mean() + + # MSE loss is already a scalar + # loss_mse = nn.MSELoss()(output, data) 
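+        # PerceptualLoss in 'net-lin' mode returns one LPIPS distance per image
+        # (roughly a (N,1,1,1) tensor when spatial=False), hence the .mean() above
+        # to reduce it to a scalar before combining it with the LAB loss below.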
+ + # Combine losses (you can adjust the weights) + # loss =loss_perceptual# + 0.5 * loss_mse + # LAB-specific loss + loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data) + + # Combine losses + loss = 0.3 * loss_perceptual + 0.7 * loss_lab + + loss.backward() + optimizer.step() + + total_loss += loss.item() + + print(f'Epoch {epoch}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, ' + f'DE: , L: {l_loss.item():.4f}, ' + f'a: {a_loss.item():.4f}, b: {b_loss.item():.4f}, ' + f'SSIM: {ssim.item():.4f}, TV: {tv_loss.item():.4f}, HF: {high_freq_loss.item():.4f}') + print(f'pLoss: {loss_perceptual.item():.4f}', + f'lr: {optimizer.param_groups[0]["lr"]:.8f}') + return total_loss / len(train_loader) + + + +# 保存图像对比 +def save_image_comparison(epoch, data, output): + fig, axes = plt.subplots(2, 5, figsize=(15, 6)) + for i in range(5): + # 在保存图像对比函数中 + axes[0, i].imshow(to_pil_image(kornia.color.lab_to_rgb(data[i].unsqueeze(0)).squeeze(0).cpu())) + axes[0, i].axis('off') + axes[0, i].set_title('Original') + axes[1, i].imshow(to_pil_image(kornia.color.lab_to_rgb(output[i].unsqueeze(0)).squeeze(0).cpu().clamp(0, 1))) + + # axes[1, i].imshow(to_pil_image(output[i].cpu().clamp(-1, 1))) + axes[1, i].axis('off') + axes[1, i].set_title('Reconstructed') + + plt.tight_layout() + plt.savefig(f'results/256*256LAB++stl10_epoch_{epoch}.png') + plt.close() + + +# 训练循环 +num_epochs = 50 + +for epoch in range(1, num_epochs + 1): + avg_loss= train(epoch) + print(f'Epoch {epoch}, Average Loss: {avg_loss:.4f}') + + # 保存模型检查点 + if epoch % 5 == 0: + torch.save(model.state_dict(), f'checkpoints/model_epoch_{epoch}.pth') + + # 生成并保存图像对比 + if epoch % 1 == 0: + model.eval() + with torch.no_grad(): + data = next(iter(train_loader))[0][:5].to(device) + output = model(data) + save_image_comparison(epoch, data, output) + +print("Training completed!") \ No newline at end of file diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/losses.py b/loss/losses.py new file mode 100644 index 0000000..1c918ba --- /dev/null +++ b/loss/losses.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from src.helpers.utils import get_scheduled_params + +def weighted_rate_loss(config, total_nbpp, total_qbpp, step_counter, ignore_schedule=False): + """ + Heavily penalize the rate with weight lambda_A >> lambda_B if it exceeds + some target r_t, otherwise penalize with lambda_B + """ + lambda_A = get_scheduled_params(config.lambda_A, config.lambda_schedule, step_counter, ignore_schedule) + lambda_B = get_scheduled_params(config.lambda_B, config.lambda_schedule, step_counter, ignore_schedule) + + assert lambda_A > lambda_B, "Expected lambda_A > lambda_B, got (A) {} <= (B) {}".format( + lambda_A, lambda_B) + + target_bpp = get_scheduled_params(config.target_rate, config.target_schedule, step_counter, ignore_schedule) + + total_qbpp = total_qbpp.item() + if total_qbpp > target_bpp: + rate_penalty = lambda_A + else: + rate_penalty = lambda_B + weighted_rate = rate_penalty * total_nbpp + + return weighted_rate, float(rate_penalty) + +def _non_saturating_loss(D_real_logits, D_gen_logits, D_real=None, D_gen=None): + + D_loss_real = F.binary_cross_entropy_with_logits(input=D_real_logits, + target=torch.ones_like(D_real_logits)) + D_loss_gen = F.binary_cross_entropy_with_logits(input=D_gen_logits, + target=torch.zeros_like(D_gen_logits)) + D_loss = D_loss_real + D_loss_gen + + G_loss = 
F.binary_cross_entropy_with_logits(input=D_gen_logits, + target=torch.ones_like(D_gen_logits)) + + return D_loss, G_loss + +def _least_squares_loss(D_real, D_gen, D_real_logits=None, D_gen_logits=None): + D_loss_real = torch.mean(torch.square(D_real - 1.0)) + D_loss_gen = torch.mean(torch.square(D_gen)) + D_loss = 0.5 * (D_loss_real + D_loss_gen) + + G_loss = 0.5 * torch.mean(torch.square(D_gen - 1.0)) + + return D_loss, G_loss + +def gan_loss(gan_loss_type, disc_out, mode='generator_loss'): + + if gan_loss_type == 'non_saturating': + loss_fn = _non_saturating_loss + elif gan_loss_type == 'least_squares': + loss_fn = _least_squares_loss + else: + raise ValueError('Invalid GAN loss') + + D_loss, G_loss = loss_fn(D_real=disc_out.D_real, D_gen=disc_out.D_gen, + D_real_logits=disc_out.D_real_logits, D_gen_logits=disc_out.D_gen_logits) + + loss = G_loss if mode == 'generator_loss' else D_loss + + return loss diff --git a/loss/perceptual_similarity/__init__.py b/loss/perceptual_similarity/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/perceptual_similarity/base_model.py b/loss/perceptual_similarity/base_model.py new file mode 100644 index 0000000..73882c1 --- /dev/null +++ b/loss/perceptual_similarity/base_model.py @@ -0,0 +1,53 @@ +import os +import torch +from torch.autograd import Variable + +class BaseModel(): + def __init__(self): + pass + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = f'{epoch_label}_net_{network_label}' + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = f'{epoch_label}_net_{network_label}' + save_path = os.path.join(self.save_dir, save_filename) + print(f'Loading network from {save_path}') + network.load_state_dict(torch.load(save_path)) + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'),flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') + diff --git a/loss/perceptual_similarity/dist_model.py b/loss/perceptual_similarity/dist_model.py new file mode 100644 index 0000000..f728a40 --- /dev/null +++ b/loss/perceptual_similarity/dist_model.py @@ -0,0 +1,280 @@ + +from __future__ import absolute_import + +import sys +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +from .base_model import BaseModel +from scipy.ndimage import zoom +import fractions +import functools +import skimage.transform +from tqdm import tqdm + + +from . import networks_basic as networks +from . 
import perceptual_loss + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + ''' + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = '%s [%s]'%(model,net) + + if(self.model == 'net-lin'): # pretrained net + linear layer + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = {} + if not use_gpu: + kw['map_location'] = 'cpu' + if(model_path is None): + import inspect + model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) + + if(not is_train): + print('Loading model from: %s'%model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif(self.model=='net'): # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif(self.model in ['L2','l2']): + self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif(self.model in ['DSSIM','dssim','SSIM','ssim']): + self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." 
% self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if(use_gpu): + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if(self.is_train): + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if(printNet): + print('---------- Networks initialized -------------') + networks.print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if(hasattr(module, 'weight') and module.kernel_size==(1,1)): + module.weight.data = torch.clamp(module.weight.data,min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + if(self.use_gpu): + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref,requires_grad=True) + self.var_p0 = Variable(self.input_p0,requires_grad=True) + self.var_p1 = Variable(self.input_p1,requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) + + self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
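+        # var_judge lies in [0,1] (fraction of human raters preferring patch p1),
+        # so 2*judge - 1 rescales it to [-1,1] before it is passed to BCERankingLoss.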
+ + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self,d0,d1,judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) + self.old_lr = lr + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() + d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() + gts+=data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) diff --git a/loss/perceptual_similarity/pretrained_networks.py b/loss/perceptual_similarity/pretrained_networks.py new file mode 100644 index 0000000..a70ebbe --- /dev/null +++ b/loss/perceptual_similarity/pretrained_networks.py @@ -0,0 +1,180 @@ +from collections import namedtuple +import torch +from torchvision import models as tv + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = 
torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2,5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) + out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + 
h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if(num==18): + self.net = tv.resnet18(pretrained=pretrained) + elif(num==34): + self.net = tv.resnet34(pretrained=pretrained) + elif(num==50): + self.net = tv.resnet50(pretrained=pretrained) + elif(num==101): + self.net = tv.resnet101(pretrained=pretrained) + elif(num==152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/loss/perceptual_similarity/weights/v0.1/alex.pth b/loss/perceptual_similarity/weights/v0.1/alex.pth new file mode 100644 index 0000000000000000000000000000000000000000..1df9dfe62abb1fc89cc7f82b4e5fe886c979708e GIT binary patch literal 6009 zcma)9d0dTK_is`(5R&K?Dk(JUJbRy>wT_`8QE4#Un`TNWPlMz-I=H1mC?!Q`Fx?wb zhIBO;Zc)0(bY)1ADRYJlc~4J{_jNzNKYq{WQ+uEFUF-W@d+)W^bB;tIMK$d6HpkB4 z3-`JWF_$Pzf2=6|FXrS;yfha_Mnp#hM1=V)35tvgi3sPqQ7_f#xWz#}Q6bBMqBun{ zr)1)#7!n@M^>SA7>J=3n5gE-XJ1cl8g++uf;8dKIV!SlMLZYL?f_#F)14I16x!y@J zQiT91Z&*-3w3IeV)ip9OC^9I}J|rMom6P}86imE1MH8=qVIkp`=8GZ%gTl z76eCgN&(RkkpaQxJ~7cDVNt3+kwN}3Az^_&(Lv!+5s^MitW~*QCMI4gQBpU*1wov0 zkH4y8ScG4+M~_jJQ}I%@qBb^+^dBSOdV47c#02^YExA4^^5b2iXE@85ILkYam;T8| zoAiv*S+YyCJEt1t&8hX|X7%sf)Ptqc!OForyA@+OjK zM&D0Luh~Oy{crmI|4VOU#aJ=cmX_Rr@ALyx6eRL8`sosB^4;lwo5%#TqymF_1Z@5- zp#6h@jg=LVu7Ml;T|g&ALGHf-ax#IQiA-RKRA6Y2z_@=4=>8x;1k|>BgMRuhp!fd~ z&<_eqaOZ}Jx#3=Yd=>>o`vv+%`*9W-)e>vh&e_EksdNxIazX4%n4miesm?bwu%(=@zj~_uR&Qr{J$-tRCO5?a$Vs5s~IHyPHTS0Mi#oW&_=$9TcS*=-e z-ePW^4D|UCwBmfloL>*vT}1v-{?cCL-QDe^{ownJRp7gW?1D*KUU#qfb{mv_=p4QH zPqIQ*4l=Kj`#+v}Nn~L{g6Quj5hP9P0kmeVLBW_ojO*N3GOsR|e4aiG+PkLPihLn5t&VvoS@UQeiIvmB-!xX>*IF0Bg{Zxx8*g44hL_b+7=6oOFv&*@ z-L{UcTbZr`5yr;k^F9{$wCGPAJD9TNdx>^$EnD_eFxA@_23Ctq1$*3)kLivxJSQaK z(?$^68~%c7`*`fObwBxJ2jl0oOVB<4bdllnQf!@wL!(M~pBeoJk;S`fd7kq1pYU?V zCG@AOY0736<-4?aJG+yCiL2V#8=O)!P-2 zzKm-A4Z++Gy+tbCCm~mT8Qs#EAlU7q0?ANAyW{j}F)Z8bgC(w4_}KU~29K6b!B>Sz z-}iQdTDq;_zGFB;X!cF`8zf5H3!(0h1AuyW4Zo?3{A61Bfh#^zsg1hz8{`7az?jQD0dryR*@`5u-+USiAu>u(1Jc;*>iC+#*?w&%kK8fGg6&`v-&<3ug zdVC=cFApS+%nf{bL>2SmGQsz#4MgQ`AzA19(?IJ+;v`?)N4xkjWA!*eQ@kEgx@RDg$oK4?kJo)tcO(VfEh^x~1XiT;td5+wR1ywMn39KefP)H;b%9XgVQ)3_cH1YH!UbGPOGHHHVeqR zF=_bhbv#9TbzE`$39O7v!~37>m~*vJFwXTpUpN1ZG{#@7HiLOy7Lnvn#0?uy(Ds&5 zLXXmeI9K}uRW>UUBsmxhTTMG*?YdR)QE?Y}^Qwjt?i>@EoWVZ+sz;VoZz5mDY@#7u z=_F&9uV{#a9-mMDGpiYk=Gjy&HxwNYt>e$h|3VbMpH9&aaPIMYtiRD$dcePgEqwAj zlaPCr+WtBh4cPQSDd{eyuHwRx)2)%;EMqv1paOH<4^9 zp^aPa-~-K3Xq}LP+l@|;*`G6Q(|odd{qA#b3p>VqASLv4%{;zN6c0w?=}U1U)mlcR zdp(U@*4vDOrvOb?p3X}8WxSsZ8PEwJ0A^irY1_1WGdknF)I(G+3*I&%HaB?45b&yZ093 zF*P6S&BDm!@*UK=Bo%M>&cmmDits{01GE0>JRCi^73WQ~Mt=iE;&xm|n6u{`v6-$+ 
z-wgbl4D-vTyGKsu*E00dYo0sRI1ZCj8c5gP400wX1CwS=5Ut9b&HG$2Od+OUltr~k zed%U@4=OP{j0Fq#iIN>xFv>IkA#%yP@OW?rEl`{QH?oVF{hbr|y=$GajT~_@q(kxz ziOtp=oZ+7Pl)E?xvDt}vpIi%!&2X$eHQH9%h3ex~S?7d?DW2}7+P6ScdC zN%~}EYE-(9(YzE(&|(kzZs{-ddZB{MsR4Laa|quHgSSP{@_94Zu8wp%MSe1EnsSpg zCa=N^l4nfss3ZJ6!OEWZ%ElF)jJMt`FgRTZ`5sH~gI9~)&G(OqYv>JJ7AZ%UDO^W` z`8_=)l8GQt`V`$v??Fn#UpQ{z5F+UqU>B@sBy5d+O&*1mV?DDJBbJBZozQX|xzvmf zt9)*&HAo$ePwXILQ#|lpyA8tgcSIrcK9lA+5+9w&64ld>B&>Q4)vi+%#Yz-uhu>)o zKBY>|q$a@77ayT|i@RVdbCCEQ3=>(48Cb7+2>R{NBl+gpfFo>SVSWZo*Bwawn<5Dg z{X!$INskbG`93hsO@jtoWD?&`C+x0|*aULl-UFd31F-FlE+r~M@QY+SJF#jW(Tg=k zNq-BRE857PT}9ojzyqWXUtZ{Xeg;gNI1S=uINhtm4S?`X)W*>(gn~Whm+G&u&XSisy8k zaC(+ISZZA+n>U2Bbu$-$$-rCivv3PG#w^9QihN9&qCr2tRe-oh+rc|agKAygjD3Bm z@NU&*Hh#Dz@G_0sp!8LCZdE1*#Yn#rjLJrg(|XVMkM!SBRA{d)DlrHHhLpj~V-n7fsYXM%(ON2mcWl@+4|=*_JaftSH+o2h5}FMOtp z)G$%f_WfjQizZeo{Z0y(WMbZ;!)WmaSeHaA)E=D#oM$8mf+ z>|J+19D6O1Jgym!S1+d1q#_qIpTrOc-DynXkwjek+uu-nwgmGARKWgU6ls02EuA!a zB@B7g#kAfzifLniV}Fm`fX4bA%qD7z4nk8_cm8^UnlJgix2Q3ib%(HXXmFM za+CGL8M_@YT$~Ky@ke2TN(dNYDEs)Qky!OnmE63%240NTqhzZqRCaiRp>jX`<-%Cu zsxKw*xn74p_-PbW&mD+r&7SDHaU3=u(m}ZR7u@%oiu<0Fp~9YPBd=q9HE;b>zOERMPd(WVs` zFP3275(YA^>SD96Dh_B<0R2i2s6Lwmmjc>h;p8efu>3yJin)L;!5wU9TO#YL&;eK9 z9fLLds=@v1RWcF|f`hUydwG@ynN=AtY!Wx3rTtLIeRfR9O%6lbt@^_D*5hp3k`6Yb zupDxxJ%oW{L)gdm&2VJmB;1tv6to-7VQ#nr+Ncl0h5ECh@j|EFbgGC+0ikx!ofXkD zX(~89y+sXve#km(Dusr4O+2Ng1Z9W+vdhlW!8=U}u>8CVuAKJ(G@`3u_ncNxGt`1T z0a?&5LJR7>SL6N|O)#)42V;p6+Rf^OZEu~02EmABMcMGPtq~nnwHR;5DGC$*RKulK zYDglah@1Tz>tnp*k#e7V;Obo_H*d>#uGaUyIjPh}6(nL}ZQ8Kw?-TThvbinyK7@3y?_APDFUmjj` zepWK>F*^;lCrbd`lW0+ej&P3PA>E;U2@SF|X~XKx(3sr)F5PE<5IY9yz)+ z5GKxA47tPlVO_c|2E4Tu2Kfl-h{QH_Pse=nhqQ+s&~1WOp4Wt%Epypkro(YT4Fy+G z3MiI;Wv?I7rtjCkrYFyt(!x(4;7ivr@W{(!i%y?opEjh7)3w|Wm)|xDZOv2d0?IRq zMY@`B9~Hya*+uZ{vvg_~a}-0D-XJZfGvRLjRGewlpDvg^2!Gdd$0dWWLi!XF$n{)m z=X2q2yW`@Iz|`+#V`{ZwmX74ldqT#0R^3Vh4f6}Z;LKa17XOTm#Xdr_{b{IXvIjgH82o%=2b02_CArMG*L7lBvVKiMRZG{QY2(Zg+?@JRB6ya zn)C=wN=oyHW)1Swl|J9+`RjLHuj}r!-}|oj+H0@(+GiheQfn1Ler|F2ra!-Es-}6I z)TCF^U4KPFNfTGqh2de5VZLD@UW@(0BZ9&f2!Erl3Xw||`guhJ#rQ=CrA5MCCa%&! 
z3nGOwPQ6@ZA_Bw0BZaa~Qm(y1!WIMw<(zsga#acmii`~L^YU9TFUV(sP=3v#E(YH& zyCHtQkzHmZ6sCmF^9%QzXCLGnsUVc}6iS(R3Z+e4)k1<6SeS=~&GQQ}w=|FP3knE~ z6!!9s3=8)SH1}E*859zs;1%vScTrHtJg-Q<1rcH4UW-R52xUx6T;(FVhWP~e31z#- zD>#IN`9x0Z?o<%UxhhzqwT&eY!G#`Xo9Qc<&qs#yQ5ybTnbT@hkG`k42x(TfROEBpF5?FvGV`;_kgxY@z zbpC$`bh`-jx(RIlOQ8RM32ZDF3mfEt(BLnDVONNx{ud&to)C!}Ju?{GMKGkBfcIYl zqyG?CGrvPL^e=(&{||wQpI@SjaF|GF>e|~Y)GyL!o=>EY&@8Yk^@n$*{s5^5KB!reC;i$kromM8U3jfoI&{`z4>4A7Au6_OonOg`Mkq~-7 z{5NPR3+dcf)3 zOjbfykR#3g?T2!98F|H6H)ttbED}cbfYIGdJmC_NFs7$7wwvjXNQ6s8 zLQxO2tlLqL{#Mnl zji~FcrwI1?Q_s6Amdu~K?&>c4`|fWo5ZC6eyKdpi`>e#4QJ!2%=Z)D5rVrl=N-Bcj$60$!@|EQ> zVioAisuZ%UCW7|Uq4Y@UD#(u0L#zJDSo_ilH+}4dkHe?4(L?Vs;E;;bKAO<|-?M1@ z=Q<)&yq+%KV}XI=f6B=ADaJyq)e(}*HoO4G&1!hKgNvK3$Zn+4B;SO}zo71#mbqm~>_!?@`rVf9Mxk({csuCmbi8o2luo7Jd=56f?oU)~ASHD@%J+kO&11$+a}cw{ASeMm6ftYt@1 z9>l^aqcbFO`DRS{s6=1DGP-(<8vVV7qIogMysf2evp%sZsUN6j$p*T?a0VXUF3&x> zJBRAP)A$cwyWK*5^ZdSeOG?l^6S%%Imh_1AKDz9NATuaU{< z+%yUIs*WH{7g`zBCLd54k%O)$W1u}r2J5~I0A1<%jAv^)GuwYF&wsHYe)}m0S~974 zxUWA1ZgnRKgEygf9Aqb2spVyjl7Ud5a;|hNoBh1f)?|8KC=(L zrxTN^Pm#sU5` z<2CqYj~4dt9DvD{o9IuSNnn4xH*;r>DU$UkX>{{$Y$f*uZEcI-wr>*gD*g(=7hZxu zJ%jztoiybY6oDN2d7JZ{#uYJcy<(VUb^D0dJ1KN(x8r0&%^+f=3k{Mq1Ibt?PO4!B z%E&y#BZ)`I&&amFwaMCj3v4{J4&FE`<5mSJy!%;?zU~)~&j#i2ZY^JicSlQ-C7%6h zq*8w>vt|W6W{!*F;3X?V=(;noZ{A+KSazB|y?hNyUe|*@-2l=hH<*|wX*gKTpSiYN z3%-}s!16*H{{Fj9U`Ch?$x+k6vo#9hya<_H#T-5S(MJ3AH0s(*l6zp}ON6OjV*dst z=uyF4PewcV6|UbL47*M3(LQP>qjm8rLvB5wO`)49@mWk2$`)X=j65w&vjjRto2HN} zOycD^tV;W47-F+S(EH36_RyG9OtX?9qnU0`r&92DM41BGih@p`lE|;;FnN89gpvM@v>$r~`iSDbUbtSw=eD1qKxo(iv3ck}$s?3o*oKtduMPHV{Lep(bR*e*VP91==M&DvQ9~nb$ zi!bc8hL=3QcfO$U{3BC2{VeZt`w;4#ArI<=@L#NTh9L(N#cOr&<*uM97eXvzu#H?Da zA;wcp{o(5@YcckF?53+mkYlixNhn^)aO2i8#n2n*ug8D!XC;SheveD`m4l=NK0L~G zhWL+bZ(Eo{R0-%9eof#|h%L&a)7Z`13sAH})E7 zeSLuhYYgTdo>av%uR{fM*4&37r5nhp2Yt|OZ9VvT4#j#|d2wE=FfWEX&~ue z#0K=PWV1%taSL9JAdEtPI%I7bY(_(LGwiX~6dtmf0;*MZOu+Pw^wEnYaQuJ_o-Hf@bCDq}>AjknFlHMx?;3#qGyj1Td6|Oi z#cSE)nsZ^iVG_Bpw-#J(+y|)?U6?f9ghVHm!`;}sNJ_x9FIfa&d|x`-x7=~Kn9t#a8CN#=AgX_K#IWs2xCiz_ z&N>Aw_^H5r(X@qnP%!ti3)Mat3-cbjeX@l#XSBJ`4qvKMb!Xd?Jy$lK4VPlSJLg zW7cUtgY=>4G*a{c{%P69?J#o#jcqJi#N1(&4jg9&q|S$q1)qR#{h60|=@xwTOv60o zbmo`Sb|!rLZlJSzLBGu##q;^PYdz0wbu7qV$wSo#>f*lmk9p$wZwx+5Jw+d3NYh2$ zoHsHs^5%7>Uc>|4jZW0CULDI7uYqlhF<9Qo2Dh|zpza-w1-q(wE?0IjaAW{~hN&N; zXCB8?6vVN@`|WTe%8|)i&J){yxg#2K3ogUN5CNol7oomFJeb8Aim@oOW6X!!C&le~ z52VHGkn%_!YE=qpbM_X9@>$0l_OO%5ty;lUt&ZmWOs9bH#i1}&Hy<{teT4%uy|^z= zEy09;OswIJ)dk@DLy?(uX$A21u#D*>J?3ffad>;-I+VV)N2@nMpmHn@E|mE~xKt*w zFV2H=6Vt(N;R@hqZDw{#mBO^)SxnElKrOkB?p{?uGo+J=#jG0ia?>K_QOmf2KFYX$ zmK>)na^>1T_65oL&KP^MQDCuT4fP0_#eLJhK^!Xk;-#BaG(Eh4<^&$&E%Mgl%BL!F z15+6qGsDRVRq{loRxOVHxCmRMHKSBk^hReBv2?4gxP-CzBRW zr-%Foa^v#G;q8VP^0o3c@&B5RPvqP|*j?5L=gjNOP42mtmhyR|ad+ z(OGtMtY;~FX=$S~4lTp4OL}3fg*kVxqKfuaJq9~9^tn#E2HMo|mSXQWRIb(lZ_gja zv1-dHgvnBu>jkv)y9&*JYe_xFCvwjD9jv->e>(Dk0d23iPjo&^0Q*q~NcboX*u372 zdd)&%~Y<)?HSnB|16uM=r%6)ZGQr77AhC%Gz~gwI?}2AGY3 zT^UJie(NN5jF%ANLw3_j|E-LfRTjv}>tZkL!UbIKj04q1H zA+ME9K9I!e$F89=H>+Ted=<48=F!)FAsm#caJ%cBxYCdqERfSBZ>^2ECeO8`bLwHd z?6RCw?EMxCcPindHLI|0;Ai^ms1Y}C^dcI5cM(;2;lPdBeuZ77F_9YQYti5r$z9J& z1M$643ke!~k{Xm+<5=4$aIfiy2Bz-dU$~k5dh|GHsvSvnvnzRFxytx*&O9o}D}fQo zyUEEgzEHa>Sa*E=_j@cc%Q=RdUZG#1lP&54}g!D*Vg=rb)P`^e_Y zbGVYbUgEWNU$_^CxD8=4+y$tjT|faXxIU+9=;xioI9mlpm?&$@jUKa&O2~ziTD=P5 zE_g(Cs!rgNZ{6i7y7i~Q=Vht)LN)4B{EYWwYi}H$V8;b+yutRLxt`f3cbRJ54Z>mP zPEqJzM*QS5ao=TGu6DB|&MlGVCgteT2MXo1%5yE9JvyA0hs=SFkVLp)pNRQZ8`1gg zQ6^wWGcD;`O7|b*(T8CjbWi6=;Kl5QnXx~Z(AGilc+C^>Ja)~LCb1$)G@DZm|Vgo5i)ThJSYE9P@_^?`ZZ|>oN>;94E$_rW8}nCMVEI 
z;A6a>3hX&~229_bgW&0U_^MAUO<2+fH6gdC{;>+`VN)X5^`@E@u0Bf^4f6n%%X!>S z8($cYv$*@;zJc>&OR^)_o<6-U&4o|U<|gqp!N@fgmYknSrHTyk$CRz?7zG>Lq#8-x zP3|z`m#rq>eXLObfHdl>oyE$t7a&Dq1g*|Kje|XhW5i$;`XI~~{UpW!wr|7RDTT0c zo+3_F-cD=MyKAe8qc7F})=UkqWMN1|I|genB_?-~Xq=zQjCM6a=l#=Q|Je1^^YIxF z_$?v)2J%Zxd{CXSn4Siwy2t)xOxS6!CTK#_k8KFynhUT5~=_>1z8SrA8CRyOiR zwhZ9(Wv{Xp4L3=&i627^kFr-LB*Ep6@d zTeQ=}G5%~CPj1M&;GL9e`p8`yYz=t0_L7Sj+nBwAYjLck4q!S;M!s&C4bU;%aeLckiNB?>xlEV-cO#JGbW9Bl@=x5gPOeQG4TY7=i;AhmpqwB0K z^_xvXqU=N#e08;^vDUqrUe5mfHFFn`%voNQlX}9+YptR$16`?Rd?~JNcBG;` zF4R4LA6_YO5OjX;!>%k^L#!U{C*i9Iy}3J{wyCF#fI6s`8cwS_FO#Q9&v?VX+ll=MJ8jHw$u5Lhvs-a;*O@Im)ETeLno8?B z(#f5aHBjo_hT~EOzH8< zj0_(f;vE1>7uoV-+k=T)tcyZaTuoVSxrJeRFKIdrqK;4Ya!Pvk&t0}?w>AFq z7!3~xpTXhA(zvU>jOeIqh<~%At(>ORs*tS@kC1fFr-C`&LaG-q0JAII$efR-@XWbF z`Z{MNd@$}LyYCe+k59Sbvj}B6Tv3L9_&yI``1TUpx%p-j`Si;QT>bZ;X!Q>G6DQ+F1^IdcEOS={5G7auPJ&o=)Q{^dP5gGb1SOi#ZW3 zaQD?Qn%8JRwTAUVw}2h=V}`y!c^@aP1KTc~z~gDnq@ysC^ex(r<$k+~3V#Ku)s{w& zl&jECI|RRJZh(J&$k6E9x;&|+O6&>cYC+$qefUAJk)|B9;u~LEfzS7+vNEoRX`O~C zj^e+8!M(C@;eB;hU{b_zQugSdHj|4XdusW5svvp4=AkvJ{Xl!DP#a;th8 zIV0bPT~(q@wG~F;qK6;QjZtI=?~p_0a18lSQvk{*wu9t@*{CCX5|x({OuXOyJ|wq! zE#x=d1C4+)aQ^jKwsHSj=1P?*Fojw)-Ljat{;Z?r4TJb@&faA4uhm4k{sNKwM~b|X z)aTzQa^srx2H??mt$1^qzF7C;<_?2hMtbn7!ijnC%MtcE^un0DpG2`}7IUZmD%l+{ zQ=m583tc`7`FZXXowvy0kni7^Xj5AplURU?cOGNGX?wokkX5*Be-eIPJOO8QK42W9 z?&14^ub_ap7#fc*BvCO(@#>THB(*{xzIp2KOZ~H%Nmn<3^J*J9(a8Xe6UKsfnF&p- zXki00t00j)!Tn(_@Gk8LteUWdzT2WfNIj(@MS`*A3>~*+9i0>|&kS*tfptlbUC1VfH#PjiIaaD*SvEO=(xJPNg<-G&&*;hN&9Zl0zEq37@6hAdGPe`I@&+tJ(%;gZ1Qh!BIz??*>=fuP!^L7eeT&{^<`5!(>D~} z-1mml7k}c>hz+ceIS5DI+hRwkG=H-GQMlulOuo}CT{&mN@0WgyoYzt2N9>a3YA24w zw-L#p<)luR4cvzLXI9bvFXGTkZUqM43WY)U$6|upb9T_APN=sPQMuWP_*|Z_DJmNo zrEfRkt(+h9St1RY$~xru*+%rKm7~ThQt?qv6yTV9tiw)exEeYbN1bgXoe2%RpQh(& z!+sHpzLYWg)k$oML5s~w*v16pTnCraD%h`iq(H+X3zL`M#rIA{&^Y-Fdezkvhxm1H z&tf#Q(Rd|8Gg9bF+9|FH-`XfWAMD7llD`ET&PXr?bF%UM$U-v4e;2zT@3c7XLth!A zxi^J*F=KFE6v3-gQfQ$@Do#=_VfpoWyyT~yD8<}?`9+kq9+FEe-xC+HD}g|>7|t}DyMmSf@PTJ(_7qxo8Up`~1j zbnG#qbZ91yInfCRxNNk1JDZemW|_Oa_Mk_uA*w%gBWtRWCcaxk()UTy(BydRTVKf* zt$hnlb@xfaBPU#MeTvPyP%OATKak9n+>Zw1SMy$tO~Drh`^ex=>iDednI?3m(WvPq z!%=PfDY4#XUpp=EH4ep~#uacg+lBvHEtTEa+d$ANb&@>qbB;_Nxsk+LzeDbL33>jx z14cZ`A#&bwY~#C8WOJqibN1F@HuXdi?sM^?qU#p~E3enE?)`R>cU?((b^0&lp7$fZ z)g$Ts24@l+l)83Pz4Wd}~7d;TEfS}()(D)EJbq5GMT-v`wbW8-y|>A93UNLiuCg0pD?;~gk>I^`2w=aTaSe z*5Il3^>FcnAr3NK$CuEWM5G=zVcO&!HhWYuNYyb5IC^y%md!RL^F4CO^fE=UrqWUm zkk1}^WLHT5XgewJ^K;ju+K&ymc*0)nDl#P`NR za;LbQJ)V}pKfZ4u=}`O(KaTGfpL_IImyvIQTiO0zG31F-6bhImQmwjD@TA(@#(B|2 za{iek-LU>MuHUhh{gfq#S0@jozivq3&nFBnG=CvD^&^h>jBa9g)hiO4R8!iV=8qq; zG6WJ!(%G9M4)812Td+ShP7ovie4=kzh%TznFkATuj%rLJW^>K3X-O-qw&pM#oF>x3aY2xg%MyxtBQp0~_?%UwzYw%a+Yu@99(UBm0Du7>**EIu^uBc^|CP zD2D@b8Mrx#N9|6T@VA^8NY$ct3ieJI$()|g@|WE=C5o#i<7DMeWR~|ZLE1fIax+$s zEFFK8fcai_)+bH2VYLX0?S7FBn(Fk*4k=ncLkf+ab+UVgnUgnza`Dbn4V>4hMa=Z$ zKw1AI|H-Y1Ha?GE5+%b#Xi15NhW5T>cv3F(U9ZpYQ>ITkeD|}mPOp?{-cB0qBjYcI6;=Q4bLQRhnSH| z8}e~qJ3+}M{fJDtA3EkcQiu3_)Z3IudQBKa&Tq3tNws+xA2mw6&z=6~E?cYc2?saC zgWymTJLCNX`mlN?8VpwIdShnAcl}z)Cfn4r)eVaD3dUe+@yZ7tSh~OecN& zYLMn{PsBBMqV#LA&H0xk+0Mnq5cfmP}frzrbX_#e%OxF9>ejkHkLiYUG}YA{Ktw z&F}coi@$T$WKgpC&R*?QAsy8h$U}=ycqmhaXgR+nW)GJ_IW~f>xf?Ux=m{wZUBwRH z;YMOFq>}L!3Vi1k(d1D4DqQ?yD6xBdlFYy7j7RuDYO4iwi2q uz{Dvu2^{7NtxXuIYE~P51k~-@kt6dG^_RuXjK1yPmz?^?ue~JdpyKh_CB^`K%VUZQ-O@ zJdp)2L42lg3vn-v)y7!)x}f7Z&N&=nyZMl666 z9UBlb%Qv1A8WAVs8yn;w9~u$p%L$5#i;ne8Fpy!ySuA(SIKg9nD}oq_q330OiHP>& zxD7SRFp};v`er7krluyQ#s-X(yF@^Ipr3^vBV8bD>BMny5Mnt9J6H-1VGe8PEV#;d z;`iGwPlm74XF#Tu& 
z{g@$oga4(M`)_&^Gb1B&BU1xL{vZ0W1tLPi!}JS;1fBn@`#&#+1;z;k#t#V?{;z<- ze+w8J8=DyDndvbT{t=ic5D@u~fXJ}G-xtFIlLP{aLjp$sE1>jm0X-8lQ$rI*`5%GF z{~rOBpr9OAMwQE`xl8*-266lX{WyNiln{Y&)CI*m?ARH&?({z@= zlmbIBnp}nz^0QHkjkkU`jFeY5abQm%lA{(}o9%IgBEQZ0^{|5CLHkZLLNJB~n%p5LbIoxP9 zr1aB1%v>%rZy2&3A{%y=9%I90ei;UB{|)Lh^SO-OFlaxdWXL#h8OPzq1w%?deZ)9% z8RubW;Skxd#q<~#F5@~3y8Ro}XBKf8_hHauNXdxtv4-hZW?APtd4`#dl^ap}&ojSryZy@1pF^Xip0gf-ms;Z&jn>rRady;747)pK6gk#5Mp>CH9 zrR;S?_4QR~>hi;nMi;zje}r(!MRZH(0km@YFm&09@y=V}XnzpNZ|5T6YXbrUrsCQ7 zQq;6>Maqmegr};~HyuA@9ZZM+3I)2c{WH#=ybPUvx^Osk0f8wg*wSCiz7*y}DuESv zyt@+~Y6>*_p*cmn*Fj%s0y#|l|TmJ@0P6sxfeuI~n{=m&gme|x=fTt2ZSkkl)DKhfpYrGMsU+}@i zzGuJ9zm6LRMCfVgH7t|Tz=X(=FdnSKyge=8y$Pp$gZ0?lF`N3S7>dDaSl5(-`RDv- zaOy))*h3hXsbRb5M`Tw|L+qdGkn{1OPfxF4(;WtD>wmyBT>^z^M%Z&P3y-%2pniWT zx<%LHrSmIzck!`QXBjSdZpS006Th^+Bq!My>=u_^d~i@C3mtWe?A3(pp#+G%6(gD{ zM)TwDKymABq;<22vsRsYoMs|v#(WsNykYm*9zx!dwfOW%n(AdXLt%>~=@`Z#UGRSH zjjO<`#)-7;%VYNb?>QLN$3srJ1cKS;uv*Ln3KuS7=R8U3{QM3zUE8tG-~xgts-thW zCH8e`Q&9IU$c#M$>)gwbQ6GTzx@L@@?SzZsi=be76%#fsrs+TCl5|r6)}}wf6NfFZ z_gIS^2QGqF+lbQ>_>jM=NOylPftJQYG+a%`w;*x4v09nVg@JBPuER`?EPQnM4d*B9 zV;_w^gpnF4=-6@;;(tV8+T~S9zC4D~a;i|Mx&bY%a`bSKKPFhp)5+L!#An}Uzt5;c zOA8@$#9?@fSW-r|JC)wj#k1Ag=>8stGc9E(eKnm*BDgs9WH0zy&mr|)8WInxpe?lq z1EG1)eCLC64&&fgIEu_-+i>E*To{AS+cJjw>#jrf{b~5$jX|l}D_nY<2-VyXR3`qN?Ofr8 z*vY3b@lg?sq!y5zUKtu?jp$Qc4YvGZKu;5->9M~iJ?RNXg{2hjd^4WP8Xh6cKLGi& z4zZOJ8t~O>3lS}`AmpE&PggaYp(CzB4Z_j*mS919n|C6+ z=Lu}@j~amOfl9bHHx@4kb@DJ{B~5QxcF+R%7J9pOLjvf00}(5S?sM;F@g zy@HFf4I(7f^@Oebax~7jN}+VcS@bmYLAC7yP83SP!B7wD-ygxJd^L!V&BWVN=5U(c z1SgX#Y(w+4IJAp!L|KK#rrN_`%3~<~`U4XxpW*6QZ)}p(qi#7R`tsD8BDGs!mYjey zmkvU>N1i_2*2R`9F4Uc7kDHbX5a}+$jcF4g{D%VCPiWFWmoRlaZNcWRilk{2iZvQh z5bYR4>%{m-DXzxsjth7>*Bi4$eud4BDKHGbhK+Nj5F8pqsGdsVkABBGMuK>z8hG6z zNzeHsspM`0csir$L*ZfkxbY4;_pNa)=Lgu#e%$HNLD5nr($1cR6(hbu(*7A%&9lHH zX<0HFH=RxwmO_#W;Zj+P_Q#X)h5Zc2W5!drkS(1*IS8lOY^0|4v&|H^$Uk)!_2=&( zSNs(W=ZVvAZ@N(zsYwA-wD4{go36Qg;q!G}yb6AeD1&9VY^@886Y8+*??>;Kn<((S zf{Q*UV5J)ex0rk+9W#M!lN!cydALhn*jp|K&E@eh=$wl^)^#}SFAdHiY4Yzpi|4yC z@Xk>Q7CK(YTn~sA?uN&QQ|R=rfTfH+l50M|Z`1?aLOOD`Mc~_*lc@W14|-Ltv5MCV z5s|m})^AFr2fC9wnkwH*(elohxaJjv?*lwUJlCbI*OjQTQwYB`_OKn^7o)Wq(3H4< z+Go?~jeiXK|FB2gjH{T-Hz#??5%j>>4u^j)!yH9Nhy>3>d6yocH57B}d!e*pFFLd| zpwQNXUYn`t)Dxo*PHJeX_o6^Q20u1A!yxA>LcU&rj(ai&?{37x*M`*AJ{|X>K0~F_ z1W#^$VP8Jzgwger=}eb83EwcIwIW>*ZaIWUtZ!^{J8#5=e8z@>r|cCI`*7;!6ZU$Q ztr%AE{Q5AWB5euxWrI1Ulf%W@rl+N5lzNdxKak=6$$w4z#uv~4etCDv}SuCX5MMayjVozrzGRN?Jqd4V}rrw z_xMvj4h<)S=+8ZcNORSqoxUZ=vyMl{;!x^|D1@PZ95fQfl0lvd{dw>m2Fz*^%c{cu z>r3!$PdWUL1yk5h(I63v?bl#O+dQh67z4Kp3FuFDLacfn ziK;~7_qBK7c5E`$%6p^pSOMY{IT+tE7q?rSY24BXI6R(A@A?!d`f)a67hULR2m;HHi|Y3gI6PnvHjkl0Dm4&) z7mYuUP}j$B}qZDMHO(@-b>qn0n6G!*Z7inUAQ$XGFqw zc?`B6sGykfk8$*SI(V;sKtpUZENvo5t}PDg12K5*AA%m?4E)*R4fi*KvAtJ-Nxo^= z9>YL(<0-7km`d*FRmk2d2J_Mf(LDYPb%&os!>tT>jJSmFktvvBUkZ(sJm?Ge;6gwN za^ll*-ysO~$7j%-RzdzhO9yvTCFa?$f~RvBB1Q{Q-$psQBbAJi15q#!%tZNv3<%%J z$KIGGeEYH#LOoG%p1U3+q)ITcJqp)dlHe|$33G8il#<#Jo05&)1=g@r<%54a4g0J- zaBf-&s?PF&kqpdgwBfbffE39V??2DPJSkNg8#|E{losGs$ui7-I0p--W#O~x3G8#{ zL(=jbp3mQf(N}`8`DhCc@ojLhb1WV1VjwD>1Fe-jjGM8FI!rfWlIH@Pij70*h#(xC z9e}SIGPL(!xXz>e2UG4*kXk_GJ~~b z``nCbeTtzk$PeX#G`KmK(Dv;?NMku5SNJKWEzP1H!!+>6)?&0Y7groZG40@c*gBNs zO#CHi9$taqZN1o+{Tgbwl2OxpjU;RGG5T~Ws(zhIx`KBw@k=NsmAm7ahhTkVV!A#3D|M$7xEm-!&BWmLa{&9KeS!Ip8Yh2+pn1D?> z;V3(hhz`>X%ubV~CuuM7VQ(QUKJ)Nk+5?mvEWneQk`&l+9Y#75)To|^ytr(^e4mG+ zw@Gkz(xm`{E4bupgqq3{te5qHLQyfg3_J0))r$;PAHh1-Hu8KE4gLWhuD;|W-A9wo zX(!^-f{~rhI2XU>@NrwN1WVYX zNa=V6#(gozy}lH3xL1QHL zuGkz9_@J;D#bIu!dYKIQ!b5O57=|_v!Rew6|^wNolTCL 
zQn9;Hn%vsk>FwGoOp2^S*bW~OaZbS?FAZ~*vM@5G9hWaA<5p%h9o-y(K9&$!ww6Ko zo*+jf0&(Pd6Nzm&fh?N|G;vE992YUj%jO}rI2PeOQlz?X4U%UR()!#|w0@1mVp>4= z;?yt9hE|_BgBg5R^cTlL zCbIzf*Ba5X{~o^AmmpxLE7i6M>U7#Ps`D2m9qwqFm5@)a`bJdBFpyri9g7+YsCN|) zum329ommlF!ffecy$->s5}gmSV4YTh;*H~I*7;CuJC%y*nk#YEEf1}(+1NO?0LwRf zAhtdl*B=QHt6z|deOwqN^N|*)NUxUmfbUie*-&Ge9r zXjT%MBlTe`K20qKD=8gw8;a2v8US6M3F#(QP`9B!c3I58NBCZrf)lukfD#x)^ zX{dkQg88Ro$@)Y!!m{h&B3Fo#m)c}waSmL7eDfbn5f)p|JqI9V0L>c_TI*(Zi&+}$?T??8up5WDq zqQ)_U9M8?@7Z~8K}gNF)MpH#+Vy+o`` z&V^}WG;|*DDN?wRGDAjS*?VnpLv!GF#tf_S>Y(s2w7i*w`MWBREf_byHLIa=PZJXPMVQzZ zg@|S?_)m35cC$Hr+l#TUwG_`DC&SKJ6&*LNFfyVD?w4e7rON@$C+s1q&PLr=A56Xa zoqajl4G{)w5Vk)TJvaF{<&+PNViP<|4u{th!Fu_p4<1=9z@YdP>>5Q_$|=UJ7y8(> z+yG&#ig8BZgW5YP=zU>KE&i_foFWIc%gRW3F9T;r7Yo<7la$E|dUyav#G8pvTfG?qPFrC@Nj$+9| zNO=~YFA=KQ4&_RXG3XT z7TN`w_VFtZ#Z5_gUdxAmN-i!WC?a897C!jKLVDXuv_39G@UwVWZRcZ}HXkMR-new2 z9Bm`xfK!TEu+`YxR|1}EI(D~b;Zfj9oNUcS(bilXD9eX(M=5rm%!gW79wPK| zvBWzYJ+JHt_Tu^gi==E2l;1xAF0U~Xvz z0)qKqf2zhJwK%+C@^N5mH99U9;`SeF&=in|2iKm{#>EBL^)d)@zvLlRDGfzu7K0@o zD%c|_hSVz%mENt|g=O6%SIHRXCLBi`)EM^fgzbHL6PRxl@jQ z!{w0PpA2JjE<^&?plkUCK^7)rurM0j2nI8v6R=U>)3#An;9o7myO|t>TrNT9m2CV# z0mk&_A^vJPx(?=pua^bgEFMIoim+@=1vU!yKN5Fl;m;kp@K20DiC}#d+!T)c9>Fkw zU<21N(U{K8N2zu^+DwYEM!pEPp(S8?+;FrYA7vIrc)f2ub+Do^uEZZ!t(B+~U5yKR zJY;;2MCF$p%$och;ty+~C0m2FWfHWw=^b8vEcHkuzKK;EJP za-UbgtttrLxWO=azY49pbMQ=B;NiUyX#BVa><@9c=+DPO`BF69%7m)O3Y5QKu$j+; zvwtB1Ywd9BbvR@OGH_hO9nJS+A<<`m%lAv5VVnbrqk=J+l!ddRm9UBR!l-o|s9A8( PyEGODb&D{1dnW!5&PYuU literal 0 HcmV?d00001 diff --git a/network/AutoEncoder.py b/network/AutoEncoder.py new file mode 100644 index 0000000..b2de7ea --- /dev/null +++ b/network/AutoEncoder.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + + +if __name__ == "__main__": + from encoder import Encoder + from generator import Generator +else: + from .encoder import Encoder + from .generator import Generator + + +class AutoEncoder(nn.Module): + def __init__(self, image_dims, batch_size, C=20, activation='relu', n_residual_blocks=8, channel_norm=True): + super(AutoEncoder, self).__init__() + + self.encoder = Encoder(image_dims, batch_size, activation, C, channel_norm) + + # 计算encoder输出的维度 + with torch.no_grad(): + dummy_input = torch.zeros(1, *image_dims) + encoder_output = self.encoder(dummy_input) + encoder_output_dims = encoder_output.shape[1:] + + self.generator = Generator(encoder_output_dims, batch_size, C, activation, n_residual_blocks, channel_norm) + + def forward(self, x): + encoded = self.encoder(x) + decoded = self.generator(encoded) + return decoded + +# 测试代码 +if __name__ == "__main__": + batch_size = 32 + image_dims = (3, 96, 96) # (channels, height, width) + + model = AutoEncoder(image_dims, batch_size) + + # 创建一个随机输入张量来测试模型 + dummy_input = torch.randn(batch_size, *image_dims) + + output = model(dummy_input) + print(f"Input shape: {dummy_input.shape}") + print(f"Output shape: {output.shape}") + + # 确保输入和输出的形状相同 + assert dummy_input.shape == output.shape, "Input and output shapes do not match!" 
+ print("AutoEncoder test passed successfully!") \ No newline at end of file diff --git a/network/__init__.py b/network/__init__.py new file mode 100644 index 0000000..d4d788a --- /dev/null +++ b/network/__init__.py @@ -0,0 +1 @@ +# Model / loss definitions diff --git a/network/channel.py b/network/channel.py new file mode 100644 index 0000000..98a8eee --- /dev/null +++ b/network/channel.py @@ -0,0 +1,59 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Parameter + +def InstanceNorm2D_wrap(input_channels, momentum=0.1, affine=True, + track_running_stats=False, **kwargs): + """ + Wrapper around default Torch instancenorm + """ + instance_norm_layer = nn.InstanceNorm2d(input_channels, + momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + return instance_norm_layer + +def ChannelNorm2D_wrap(input_channels, momentum=0.1, affine=True, + track_running_stats=False, **kwargs): + """ + Wrapper around Channel Norm module + """ + channel_norm_layer = ChannelNorm2D(input_channels, + momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + return channel_norm_layer + +class ChannelNorm2D(nn.Module): + """ + Similar to default Torch instanceNorm2D but calculates + moments over channel dimension instead of spatial dims. + Expects input_dim in format (B,C,H,W) + """ + + def __init__(self, input_channels, momentum=0.1, eps=1e-3, + affine=True, **kwargs): + super(ChannelNorm2D, self).__init__() + + self.momentum = momentum + self.eps = eps + self.affine = affine + + if affine is True: + self.gamma = nn.Parameter(torch.ones(1, input_channels, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, input_channels, 1, 1)) + + def forward(self, x): + """ + Calculate moments over channel dim, normalize. + x: Image tensor, shape (B,C,H,W) + """ + mu, var = torch.mean(x, dim=1, keepdim=True), torch.var(x, dim=1, keepdim=True) + + x_normed = (x - mu) * torch.rsqrt(var + self.eps) + + if self.affine is True: + x_normed = self.gamma * x_normed + self.beta + return x_normed diff --git a/network/discriminator.py b/network/discriminator.py new file mode 100644 index 0000000..6dd29c9 --- /dev/null +++ b/network/discriminator.py @@ -0,0 +1,95 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Discriminator(nn.Module): + def __init__(self, image_dims, context_dims, C, spectral_norm=True): + """ + Convolutional patchGAN discriminator used in [1]. + Accepts as input generator output G(z) or x ~ p*(x) where + p*(x) is the true data distribution. + Contextual information provided is encoder output y = E(x) + ======== + Arguments: + image_dims: Dimensions of input image, (C_in,H,W) + context_dims: Dimensions of contextual information, (C_in', H', W') + C: Bottleneck depth, controls bits-per-pixel + C = 220 used in [1], C = C_in' if encoder output used + as context. + + [1] Mentzer et. al., "High-Fidelity Generative Image Compression", + arXiv:2006.09965 (2020). 
+ """ + super(Discriminator, self).__init__() + + self.image_dims = image_dims + self.context_dims = context_dims + im_channels = self.image_dims[0] + kernel_dim = 4 + context_C_out = 12 + filters = (64, 128, 256, 512) + + # Upscale encoder output - (C, 16, 16) -> (12, 256, 256) + self.context_conv = nn.Conv2d(C, context_C_out, kernel_size=3, padding=1, padding_mode='reflect') + self.context_upsample = nn.Upsample(scale_factor=16, mode='nearest') + + # Images downscaled to 500 x 1000 + randomly cropped to 256 x 256 + # assert image_dims == (im_channels, 256, 256), 'Crop image to 256 x 256!' + + # Layer / normalization options + # TODO: calculate padding properly + cnn_kwargs = dict(stride=2, padding=1, padding_mode='reflect') + self.activation = nn.LeakyReLU(negative_slope=0.2) + + if spectral_norm is True: + norm = nn.utils.spectral_norm + else: + norm = nn.utils.weight_norm + + # (C_in + C_in', 256,256) -> (64,128,128), with implicit padding + # TODO: Check if removing spectral norm in first layer works + self.conv1 = norm(nn.Conv2d(im_channels + context_C_out, filters[0], kernel_dim, **cnn_kwargs)) + + # (128,128) -> (64,64) + self.conv2 = norm(nn.Conv2d(filters[0], filters[1], kernel_dim, **cnn_kwargs)) + + # (64,64) -> (32,32) + self.conv3 = norm(nn.Conv2d(filters[1], filters[2], kernel_dim, **cnn_kwargs)) + + # (32,32) -> (16,16) + self.conv4 = norm(nn.Conv2d(filters[2], filters[3], kernel_dim, **cnn_kwargs)) + + self.conv_out = nn.Conv2d(filters[3], 1, kernel_size=1, stride=1) + + def forward(self, x, y): + """ + x: Concatenated real/gen images + y: Quantized latents + """ + batch_size = x.size()[0] + + # Concatenate upscaled encoder output y as contextual information + y = self.activation(self.context_conv(y)) + y = self.context_upsample(y) + + x = torch.cat((x,y), dim=1) + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.activation(self.conv3(x)) + x = self.activation(self.conv4(x)) + + out_logits = self.conv_out(x).view(-1,1) + out = torch.sigmoid(out_logits) + + return out, out_logits + +if __name__ == "__main__": + B = 2 + C = 7 + print('Image 1') + x = torch.randn((B,3,256,256)) + x_dims = tuple(x.size()) + D = Discriminator(image_dims=x_dims[1:], context_dims=tuple(x.size())[1:], C=C) + print('Discriminator output', x.size()) diff --git a/network/encoder.py b/network/encoder.py new file mode 100644 index 0000000..df526eb --- /dev/null +++ b/network/encoder.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +if __name__ == "__main__": + import channel, instance +else: + from . import channel, instance + + + +class Encoder(nn.Module): + def __init__(self, image_dims, batch_size, activation='relu', C=220, + channel_norm=True): + """ + Encoder with convolutional architecture proposed in [1]. + Projects image x ([C_in,256,256]) into a feature map of size C x W/16 x H/16 + ======== + Arguments: + image_dims: Dimensions of input image, (C_in,H,W) + batch_size: Number of instances per minibatch + C: Bottleneck depth, controls bits-per-pixel + C = {2,4,8,16} + + [1] Mentzer et. al., "High-Fidelity Generative Image Compression", + arXiv:2006.09965 (2020). + """ + + super(Encoder, self).__init__() + + kernel_dim = 3 + filters = (15, 30, 60, 120, 240) + + # Images downscaled to 500 x 1000 + randomly cropped to 256 x 256 + im_channels = image_dims[0] + # assert image_dims == (im_channels, 256, 256), 'Crop image to 256 x 256!' 
+ + # Layer / normalization options + cnn_kwargs = dict(stride=2, padding=0, padding_mode='reflect') + norm_kwargs = dict(momentum=0.1, affine=True, track_running_stats=False) + activation_d = dict(relu='ReLU', elu='ELU', leaky_relu='LeakyReLU') + self.activation = getattr(nn, activation_d[activation]) # (leaky_relu, relu, elu) + self.n_downsampling_layers = 4 + + if channel_norm is True: + self.interlayer_norm = channel.ChannelNorm2D_wrap + else: + self.interlayer_norm = instance.InstanceNorm2D_wrap + + self.pre_pad = nn.ReflectionPad2d(3) + self.asymmetric_pad = nn.ReflectionPad2d((0,1,1,0)) # Slower than tensorflow? + self.post_pad = nn.ReflectionPad2d(1) + + heights = [2**i for i in range(4,9)][::-1] + widths = heights + H1, H2, H3, H4, H5 = heights + W1, W2, W3, W4, W5 = widths + + # (256,256) -> (256,256), with implicit padding + self.conv_block1 = nn.Sequential( + self.pre_pad, + nn.Conv2d(im_channels, filters[0], kernel_size=(7,7), stride=1), + self.interlayer_norm(filters[0], **norm_kwargs), + self.activation(), + ) + + # (256,256) -> (128,128) + self.conv_block2 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[0], filters[1], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[1], **norm_kwargs), + self.activation(), + ) + + # (128,128) -> (64,64) + self.conv_block3 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[1], filters[2], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[2], **norm_kwargs), + self.activation(), + ) + + # (64,64) -> (32,32) + self.conv_block4 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[2], filters[3], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[3], **norm_kwargs), + self.activation(), + ) + + # (32,32) -> (16,16) + self.conv_block5 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[3], filters[4], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[4], **norm_kwargs), + self.activation(), + ) + + # Project channels onto space w/ dimension C + # Feature maps have dimension C x W/16 x H/16 + # (16,16) -> (16,16) + self.conv_block_out = nn.Sequential( + self.post_pad, + nn.Conv2d(filters[4], C, kernel_dim, stride=1), + ) + + + def forward(self, x): + x = self.conv_block1(x) + x = self.conv_block2(x) + x = self.conv_block3(x) + x = self.conv_block4(x) + x = self.conv_block5(x) + out = self.conv_block_out(x) + return out + + +if __name__ == "__main__": + B = 32 + C = 220 + print('Image 1') + x = torch.randn((B,3,96,96)) + x_dims = tuple(x.size()) + E = Encoder(image_dims=x_dims[1:], batch_size=B, C=C) + print(E(x).size()) + diff --git a/network/generator.py b/network/generator.py new file mode 100644 index 0000000..ec56434 --- /dev/null +++ b/network/generator.py @@ -0,0 +1,195 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from torch.nn import GELU + + +if __name__ == "__main__": + import channel, instance +else: + from . 
import channel, instance + +class ResidualBlock(nn.Module): + def __init__(self, input_dims, kernel_size=3, stride=1, + channel_norm=True, activation='relu'): + """ + input_dims: Dimension of input tensor (B,C,H,W) + """ + super(ResidualBlock, self).__init__() + + self.activation = getattr(F, activation) + in_channels = input_dims[1] + norm_kwargs = dict(momentum=0.1, affine=True, track_running_stats=False) + + if channel_norm is True: + self.interlayer_norm = channel.ChannelNorm2D_wrap + else: + self.interlayer_norm = instance.InstanceNorm2D_wrap + + pad_size = int((kernel_size-1)/2) + self.pad = nn.ReflectionPad2d(pad_size) + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride) + self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride) + self.norm1 = self.interlayer_norm(in_channels, **norm_kwargs) + self.norm2 = self.interlayer_norm(in_channels, **norm_kwargs) + + def forward(self, x): + identity_map = x + res = self.pad(x) + res = self.conv1(res) + res = self.norm1(res) + res = self.activation(res) + + res = self.pad(res) + res = self.conv2(res) + res = self.norm2(res) + + return torch.add(res, identity_map) + +class Generator(nn.Module): + def __init__(self, input_dims, batch_size, C=16, activation='relu', + n_residual_blocks=8, channel_norm=True, sample_noise=False, + noise_dim=32): + + """ + Generator with convolutional architecture proposed in [1]. + Upscales quantized encoder output into feature map of size C x W x H. + Expects input size (C,16,16) + ======== + Arguments: + input_dims: Dimensions of quantized representation, (C,H,W) + batch_size: Number of instances per minibatch + C: Encoder bottleneck depth, controls bits-per-pixel + C = 220 used in [1]. + + [1] Mentzer et. al., "High-Fidelity Generative Image Compression", + arXiv:2006.09965 (2020). + """ + + super(Generator, self).__init__() + + kernel_dim = 3 + filters = [240, 120, 60,30,15] + self.n_residual_blocks = n_residual_blocks + self.sample_noise = sample_noise + self.noise_dim = noise_dim + + # Layer / normalization options + cnn_kwargs = dict(stride=2, padding=1, output_padding=1) + norm_kwargs = dict(momentum=0.1, affine=True, track_running_stats=False) + activation_d = dict(relu='ReLU', elu='ELU', leaky_relu='LeakyReLU', gelu='GELU') + self.activation = getattr(nn, activation_d[activation]) # (leaky_relu, relu, elu, gelu) + self.n_upsampling_layers = 4 + + if channel_norm is True: + self.interlayer_norm = channel.ChannelNorm2D_wrap + else: + self.interlayer_norm = instance.InstanceNorm2D_wrap + + self.pre_pad = nn.ReflectionPad2d(1) + self.asymmetric_pad = nn.ReflectionPad2d((0,1,1,0)) # Slower than tensorflow? 
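+        # Note: asymmetric_pad mirrors the encoder but is not referenced by the
+        # blocks below; pre_pad wraps the initial 3x3 conv and post_pad the final
+        # 7x7 conv. Each ConvTranspose2d uses stride=2, padding=1, output_padding=1
+        # with a 3x3 kernel, so H_out = (H_in - 1)*2 - 2 + 3 + 1 = 2*H_in, taking
+        # the documented (C,16,16) latent through 16 -> 32 -> 64 -> 128 -> 256.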
+ self.post_pad = nn.ReflectionPad2d(3) + + H0, W0 = input_dims[1:] + heights = [2**i for i in range(5,9)] + widths = heights + H1, H2, H3, H4 = heights + W1, W2, W3, W4 = widths + + + # (16,16) -> (16,16), with implicit padding + self.conv_block_init = nn.Sequential( + self.interlayer_norm(C, **norm_kwargs), + self.pre_pad, + nn.Conv2d(C, filters[0], kernel_size=(3,3), stride=1), + self.interlayer_norm(filters[0], **norm_kwargs), + ) + + if sample_noise is True: + # Concat noise with latent representation + filters[0] += self.noise_dim + + for m in range(n_residual_blocks): + resblock_m = ResidualBlock(input_dims=(batch_size, filters[0], H0, W0), + channel_norm=channel_norm, activation=activation) + self.add_module(f'resblock_{str(m)}', resblock_m) + + # (16,16) -> (32,32) + self.upconv_block1 = nn.Sequential( + nn.ConvTranspose2d(filters[0], filters[1], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[1], **norm_kwargs), + self.activation(), + ) + + self.upconv_block2 = nn.Sequential( + nn.ConvTranspose2d(filters[1], filters[2], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[2], **norm_kwargs), + self.activation(), + ) + + self.upconv_block3 = nn.Sequential( + nn.ConvTranspose2d(filters[2], filters[3], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[3], **norm_kwargs), + self.activation(), + ) + + self.upconv_block4 = nn.Sequential( + nn.ConvTranspose2d(filters[3], filters[4], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[4], **norm_kwargs), + self.activation(), + ) + + self.conv_block_out = nn.Sequential( + self.post_pad, + nn.Conv2d(filters[-1], 3, kernel_size=(7,7), stride=1), + ) + + + def forward(self, x): + + head = self.conv_block_init(x) + + if self.sample_noise is True: + B, C, H, W = tuple(head.size()) + z = torch.randn((B, self.noise_dim, H, W)).to(head) + head = torch.cat((head,z), dim=1) + + for m in range(self.n_residual_blocks): + resblock_m = getattr(self, f'resblock_{str(m)}') + if m == 0: + x = resblock_m(head) + else: + x = resblock_m(x) + + x += head + x = self.upconv_block1(x) + x = self.upconv_block2(x) + x = self.upconv_block3(x) + x = self.upconv_block4(x) + out = self.conv_block_out(x) + + return out + + +if __name__ == "__main__": + import os + + C = 128 + y = torch.randn([10,C,2,2]) + y_dims = y.size() + G = Generator(y_dims[1:], y_dims[0], C=C, n_residual_blocks=3, sample_noise=True) + + x_hat = G(y) + print(f"Output size: {x_hat.size()}") + + # Save the model + torch.save(G.state_dict(), "generator_model.pth") + + # Get the file size + file_size = os.path.getsize("generator_model.pth") + print(f"Generator model size on disk: {file_size / 1024:.2f} KB") + + # Optionally, remove the file after checking its size + os.remove("generator_model.pth") \ No newline at end of file diff --git a/network/hyper.py b/network/hyper.py new file mode 100644 index 0000000..82f0851 --- /dev/null +++ b/network/hyper.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.helpers import maths +lower_bound_toward = maths.LowerBoundToward.apply + +def get_num_DLMM_channels(C, K=4, params=['mu','scale','mix']): + """ + C: Channels of latent representation (L3C uses 5). + K: Number of mixture coefficients. 
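+    Example: for C=64 and K=4 with the three parameter groups above, the
+    output needs C * K * len(params) = 64 * 4 * 3 = 768 channels.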
+ """ + return C * K * len(params) + +def get_num_mixtures(K_agg, C, params=['mu','scale','mix']): + return K_agg // (len(params) * C) + +def unpack_likelihood_params(x, conv_out, log_scales_min): + + N, C, H, W = x.shape + K_agg = conv_out.shape[1] + + K = get_num_mixtures(K_agg, C) + + # For each channel: K pi / K mu / K sigma + conv_out = conv_out.reshape(N, 3, C, K, H, W) + logit_pis = conv_out[:, 0, ...] + means = conv_out[:, 1, ...] + log_scales = conv_out[:, 2, ...] + log_scales = lower_bound_toward(log_scales, log_scales_min) + x = x.reshape(N, C, 1, H, W) + + return x, (logit_pis, means, log_scales), K + + +class HyperpriorAnalysis(nn.Module): + """ + Hyperprior 'analysis model' as proposed in [1]. + + [1] Ballé et. al., "Variational image compression with a scale hyperprior", + arXiv:1802.01436 (2018). + + C: Number of input channels + """ + def __init__(self, C=220, N=320, activation='relu'): + super(HyperpriorAnalysis, self).__init__() + + cnn_kwargs = dict(kernel_size=5, stride=2, padding=2, padding_mode='reflect') + self.activation = getattr(F, activation) + self.n_downsampling_layers = 2 + + self.conv1 = nn.Conv2d(C, N, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(N, N, **cnn_kwargs) + self.conv3 = nn.Conv2d(N, N, **cnn_kwargs) + + def forward(self, x): + + # x = torch.abs(x) + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.conv3(x) + + return x + + +class HyperpriorSynthesis(nn.Module): + """ + Hyperprior 'synthesis model' as proposed in [1]. Outputs + distribution parameters of input latents. + + [1] Ballé et. al., "Variational image compression with a scale hyperprior", + arXiv:1802.01436 (2018). + + C: Number of output channels + """ + def __init__(self, C=220, N=320, activation='relu', final_activation=None): + super(HyperpriorSynthesis, self).__init__() + + cnn_kwargs = dict(kernel_size=5, stride=2, padding=2, output_padding=1) + self.activation = getattr(F, activation) + self.final_activation = final_activation + + self.conv1 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv2 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv3 = nn.ConvTranspose2d(N, C, kernel_size=3, stride=1, padding=1) + + if self.final_activation is not None: + self.final_activation = getattr(F, final_activation) + + def forward(self, x): + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.conv3(x) + + if self.final_activation is not None: + x = self.final_activation(x) + return x + + +class HyperpriorSynthesisDLMM(nn.Module): + """ + Outputs distribution parameters of input latents, conditional on + hyperlatents, assuming a discrete logistic mixture model. 
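+    The final 1x1 convolution emits get_num_DLMM_channels(C) channels, which
+    unpack_likelihood_params reshapes into per-channel (logit_pis, means, log_scales).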
+ + C: Number of output channels + """ + def __init__(self, C=64, N=320, activation='relu', final_activation=None): + super(HyperpriorSynthesisDLMM, self).__init__() + + cnn_kwargs = dict(kernel_size=5, stride=2, padding=2, output_padding=1) + self.activation = getattr(F, activation) + self.final_activation = final_activation + + self.conv1 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv2 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv3 = nn.ConvTranspose2d(N, C, kernel_size=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(C, get_num_DLMM_channels(C), kernel_size=1, stride=1) + + if self.final_activation is not None: + self.final_activation = getattr(F, final_activation) + + def forward(self, x): + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.conv3(x) + x = self.conv_out(x) + + if self.final_activation is not None: + x = self.final_activation(x) + return x diff --git a/network/instance.py b/network/instance.py new file mode 100644 index 0000000..37595c9 --- /dev/null +++ b/network/instance.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Parameter + +def InstanceNorm2D_wrap(input_channels, momentum=0.1, affine=True, + track_running_stats=False, **kwargs): + """ + Wrapper around default Torch instancenorm + """ + instance_norm_layer = nn.InstanceNorm2d(input_channels, + momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + return instance_norm_layer
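
As a sanity check on how the new modules compose, the sketch below wires the encoder, generator and discriminator end to end. It is illustrative only, not part of the patch: it assumes the repository root is on PYTHONPATH so that `network` imports as a package, and it picks a 256x256 input with the Encoder default C=220 (batch size and crop size are arbitrary). network/hyper.py is left out here since it depends on src.helpers.maths, which is outside the files touched by this patch.

    import torch

    from network.encoder import Encoder
    from network.generator import Generator
    from network.discriminator import Discriminator

    B, C = 2, 220
    x = torch.randn(B, 3, 256, 256)

    # Encoder maps (3, 256, 256) -> (C, 16, 16); Generator inverts the spatial scaling
    E = Encoder(image_dims=(3, 256, 256), batch_size=B, C=C)
    G = Generator(input_dims=(C, 16, 16), batch_size=B, C=C, n_residual_blocks=8)
    D = Discriminator(image_dims=(3, 256, 256), context_dims=(C, 16, 16), C=C)

    y = E(x)                                # (B, 220, 16, 16) latent (quantization omitted)
    x_hat = G(y)                            # (B, 3, 256, 256) reconstruction

    d_in = torch.cat((x, x_hat), dim=0)     # real and generated images
    d_ctx = torch.cat((y, y), dim=0)        # latents as conditioning context
    out, out_logits = D(d_in, d_ctx)
    print(y.shape, x_hat.shape, out_logits.shape)

With these settings the latent is (2, 220, 16, 16), the reconstruction is (2, 3, 256, 256), and the discriminator returns patch logits of shape (4*16*16, 1) = (1024, 1).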