diff --git a/.gitignore b/.gitignore index ed07724..da61b35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,8 @@ __pycache__ *.json *.bin *.pth +*.txt +STL10Data +!loss/perceptual_similarity/weights/v0.1/vgg.pth +!loss/perceptual_similarity/weights/v0.1/squeeze.pth +!loss/perceptual_similarity/weights/v0.1/alex.pth \ No newline at end of file diff --git a/DualConvMixer.py b/DualConvMixer.py new file mode 100644 index 0000000..e40fcd5 --- /dev/null +++ b/DualConvMixer.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# class Residual(nn.Module): +# def __init__(self, fn): +# super().__init__() +# self.fn = fn + +# def forward(self, x): +# return self.fn(x) + x + + +# class ConvBlock(nn.Module): +# def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): +# super(ConvBlock, self).__init__() +# self.conv = nn.Sequential( +# Residual(nn.Sequential( +# nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding), +# nn.GELU(), +# nn.BatchNorm2d(in_channels) +# )), +# nn.Conv2d(in_channels, out_channels, kernel_size=1), +# nn.GELU(), +# nn.BatchNorm2d(out_channels) +# ) + +# def forward(self, x): +# return self.conv(x) + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + +class UpBlock(nn.Module): + def __init__(self, in_channels, out_channels,kernel_size=3, stride=1, padding=1): + super(UpBlock, self).__init__() + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + # self.upsample = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False) + self.conv = ConvBlock(in_channels, out_channels,kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x): + x = self.upsample(x) + return self.conv(x) + +# class DeconvBlock(nn.Module): +# def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding=0): +# super(DeconvBlock, self).__init__() +# self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding) +# self.bn = nn.BatchNorm2d(out_channels) +# self.relu = nn.ReLU(inplace=True) + +# def forward(self, x): +# return self.relu(self.bn(self.deconv(x))) + +class HourglassModel(nn.Module): + def __init__(self, input_channels=3, latent_dim=128): + super(HourglassModel, self).__init__() + + # Encoder (正卷积部分) + self.encoder = nn.Sequential( + ConvBlock(input_channels, 64), + ConvBlock(64, 64), + nn.MaxPool2d(2), + ConvBlock(64, 128), + ConvBlock(128, 128), + nn.MaxPool2d(2), + ConvBlock(128, 256), + ConvBlock(256, 256), + nn.MaxPool2d(2), + ConvBlock(256, 512), + ConvBlock(512, 512), + nn.MaxPool2d(2), + ConvBlock(512, latent_dim) + ) + + # Decoder + self.decoder = nn.Sequential( + UpBlock(latent_dim, 512), + UpBlock(512, 256), + UpBlock(256, 128), + UpBlock(128, 64), + nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1), + nn.Tanh() + ) + + # # Decoder (反卷积部分) + # self.decoder = nn.Sequential( + # DeconvBlock(latent_dim, 512, kernel_size=3, stride=2, padding=1, output_padding=1), + # DeconvBlock(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), + # DeconvBlock(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), + # DeconvBlock(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1), + # nn.Conv2d(64, input_channels, kernel_size=3, stride=1, padding=1), + # nn.Tanh() + # ) + + def forward(self, x): + latent = self.encoder(x) + output = self.decoder(latent) + return output + + def encode(self, x): + return self.encoder(x) + + def decode(self, latent): + return self.decoder(latent) + +# 测试模型 +if __name__ == "__main__": + # 创建一个示例输入 + batch_size = 1 + channels = 3 + height, width = 32, 32 + x = torch.randn(batch_size, channels, height, width) + + # 初始化模型 + model = HourglassModel() + + # 前向传播 + output = model(x) + + print(model.encode(x).shape) + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + # print(f"Model architecture:\n{model}") \ No newline at end of file diff --git a/DualTrain.py b/DualTrain.py new file mode 100644 index 0000000..512b59d --- /dev/null +++ b/DualTrain.py @@ -0,0 +1,163 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +import os +from tqdm import tqdm +import matplotlib.pyplot as plt +from DualConvMixer import HourglassModel +from torchvision.models import vgg16 +from torchvision.transforms.functional import to_pil_image +from torch.optim.lr_scheduler import OneCycleLR +from network.AutoEncoder import AutoEncoder +import numpy as np +import torch.nn.functional as F +from torchvision.models import vgg16 +from loss.perceptual_similarity.perceptual_loss import PerceptualLoss +import kornia +import colour +import numpy as np + +# 设置设备 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# 创建保存结果的文件夹 +os.makedirs("results", exist_ok=True) +os.makedirs("checkpoints", exist_ok=True) + +# 数据预处理 + +transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: kornia.color.rgb_to_lab(x.unsqueeze(0)).squeeze(0)) +]) + +batch_size = 128 +# 加载STL10数据集 +train_dataset = datasets.STL10(root='./STL10Data', split='train+unlabeled', download=True, transform=transform) +train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8) + +# 加载CIFAR-10数据集 +# train_dataset = datasets.CIFAR10(root='./CIFAR10RawData', train=True, download=True, transform=transform) +# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2) + +# 初始化模型 +# model = HourglassModel().to(device) +model = AutoEncoder(image_dims=(3, 256, 256), batch_size=batch_size, + C=16,activation='leaky_relu').to(device) + +perceptual_loss = PerceptualLoss(model='net-lin', net='alex', + colorspace='Lab', use_gpu=torch.cuda.is_available()).to(device) +# 优化器 +optimizer = optim.AdamW(model.parameters(), lr=0.001) + +import torch.nn.functional as F +from kornia.losses import ssim_loss #psnr_loss + +def lab_loss(output, target): + # Calculate PSNR loss + # psnr = psnr_loss(output, target, max_val=255.0) # Assuming 8-bit color depth + # psnr_loss_value = 1.0 - psnr / 100.0 # Convert PSNR to a loss (0 to 1 range) + + # Separate channel losses + l_loss = F.mse_loss(output[:, 0], target[:, 0]) / (100.0 ** 2) # Normalize by square of range + a_loss = F.mse_loss(output[:, 1], target[:, 1]) / (255.0 ** 2) # Normalize by square of range + b_loss = F.mse_loss(output[:, 2], target[:, 2]) / (255.0 ** 2) # Normalize by square of range + + # SSIM loss (applied on L channel) + ssim = ssim_loss(output[:, 0].unsqueeze(1), target[:, 0].unsqueeze(1), window_size=11) + + # Total Variation loss (you might want to normalize this too) + tv_loss = (torch.mean(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])) + + torch.mean(torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :]))) / 255.0 + + high_freq_output = output - F.avg_pool2d(output, kernel_size=3, stride=1, padding=1) + high_freq_target = target - F.avg_pool2d(target, kernel_size=3, stride=1, padding=1) + high_freq_loss = F.mse_loss(high_freq_output, high_freq_target) + + # Combine losses (adjust weights as needed) + total_loss = l_loss + 0.5 * (a_loss + b_loss) + 0.1 * ssim + 0.01 * tv_loss+ 0.1 * high_freq_loss + return total_loss, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss # Return individual losses for monitoring + # total_loss = 0.5 * psnr_loss_value + 0.2 * (l_loss + 0.5 * (a_loss + b_loss)) + 0.1 * ssim + 0.01 * tv_loss + # return total_loss, psnr_loss_value, l_loss, a_loss, b_loss, ssim, tv_loss + + +# 训练函数 +def train(epoch): + model.train() + total_loss = 0 + for batch_idx, (data, _) in enumerate(tqdm(train_loader)): + data = data.to(device) + optimizer.zero_grad() + output = model(data) + + # Calculate perceptual loss and take the mean to get a scalar + loss_perceptual = perceptual_loss(output, data).mean() + + # MSE loss is already a scalar + # loss_mse = nn.MSELoss()(output, data) + + # Combine losses (you can adjust the weights) + # loss =loss_perceptual# + 0.5 * loss_mse + # LAB-specific loss + loss_lab, l_loss, a_loss, b_loss, ssim, tv_loss, high_freq_loss = lab_loss(output, data) + + # Combine losses + loss = 0.3 * loss_perceptual + 0.7 * loss_lab + + loss.backward() + optimizer.step() + + total_loss += loss.item() + + print(f'Epoch {epoch}, Batch {batch_idx}, Total Loss: {loss.item():.4f}, ' + f'DE: , L: {l_loss.item():.4f}, ' + f'a: {a_loss.item():.4f}, b: {b_loss.item():.4f}, ' + f'SSIM: {ssim.item():.4f}, TV: {tv_loss.item():.4f}, HF: {high_freq_loss.item():.4f}') + print(f'pLoss: {loss_perceptual.item():.4f}', + f'lr: {optimizer.param_groups[0]["lr"]:.8f}') + return total_loss / len(train_loader) + + + +# 保存图像对比 +def save_image_comparison(epoch, data, output): + fig, axes = plt.subplots(2, 5, figsize=(15, 6)) + for i in range(5): + # 在保存图像对比函数中 + axes[0, i].imshow(to_pil_image(kornia.color.lab_to_rgb(data[i].unsqueeze(0)).squeeze(0).cpu())) + axes[0, i].axis('off') + axes[0, i].set_title('Original') + axes[1, i].imshow(to_pil_image(kornia.color.lab_to_rgb(output[i].unsqueeze(0)).squeeze(0).cpu().clamp(0, 1))) + + # axes[1, i].imshow(to_pil_image(output[i].cpu().clamp(-1, 1))) + axes[1, i].axis('off') + axes[1, i].set_title('Reconstructed') + + plt.tight_layout() + plt.savefig(f'results/256*256LAB++stl10_epoch_{epoch}.png') + plt.close() + + +# 训练循环 +num_epochs = 50 + +for epoch in range(1, num_epochs + 1): + avg_loss= train(epoch) + print(f'Epoch {epoch}, Average Loss: {avg_loss:.4f}') + + # 保存模型检查点 + if epoch % 5 == 0: + torch.save(model.state_dict(), f'checkpoints/model_epoch_{epoch}.pth') + + # 生成并保存图像对比 + if epoch % 1 == 0: + model.eval() + with torch.no_grad(): + data = next(iter(train_loader))[0][:5].to(device) + output = model(data) + save_image_comparison(epoch, data, output) + +print("Training completed!") \ No newline at end of file diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/losses.py b/loss/losses.py new file mode 100644 index 0000000..1c918ba --- /dev/null +++ b/loss/losses.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from src.helpers.utils import get_scheduled_params + +def weighted_rate_loss(config, total_nbpp, total_qbpp, step_counter, ignore_schedule=False): + """ + Heavily penalize the rate with weight lambda_A >> lambda_B if it exceeds + some target r_t, otherwise penalize with lambda_B + """ + lambda_A = get_scheduled_params(config.lambda_A, config.lambda_schedule, step_counter, ignore_schedule) + lambda_B = get_scheduled_params(config.lambda_B, config.lambda_schedule, step_counter, ignore_schedule) + + assert lambda_A > lambda_B, "Expected lambda_A > lambda_B, got (A) {} <= (B) {}".format( + lambda_A, lambda_B) + + target_bpp = get_scheduled_params(config.target_rate, config.target_schedule, step_counter, ignore_schedule) + + total_qbpp = total_qbpp.item() + if total_qbpp > target_bpp: + rate_penalty = lambda_A + else: + rate_penalty = lambda_B + weighted_rate = rate_penalty * total_nbpp + + return weighted_rate, float(rate_penalty) + +def _non_saturating_loss(D_real_logits, D_gen_logits, D_real=None, D_gen=None): + + D_loss_real = F.binary_cross_entropy_with_logits(input=D_real_logits, + target=torch.ones_like(D_real_logits)) + D_loss_gen = F.binary_cross_entropy_with_logits(input=D_gen_logits, + target=torch.zeros_like(D_gen_logits)) + D_loss = D_loss_real + D_loss_gen + + G_loss = F.binary_cross_entropy_with_logits(input=D_gen_logits, + target=torch.ones_like(D_gen_logits)) + + return D_loss, G_loss + +def _least_squares_loss(D_real, D_gen, D_real_logits=None, D_gen_logits=None): + D_loss_real = torch.mean(torch.square(D_real - 1.0)) + D_loss_gen = torch.mean(torch.square(D_gen)) + D_loss = 0.5 * (D_loss_real + D_loss_gen) + + G_loss = 0.5 * torch.mean(torch.square(D_gen - 1.0)) + + return D_loss, G_loss + +def gan_loss(gan_loss_type, disc_out, mode='generator_loss'): + + if gan_loss_type == 'non_saturating': + loss_fn = _non_saturating_loss + elif gan_loss_type == 'least_squares': + loss_fn = _least_squares_loss + else: + raise ValueError('Invalid GAN loss') + + D_loss, G_loss = loss_fn(D_real=disc_out.D_real, D_gen=disc_out.D_gen, + D_real_logits=disc_out.D_real_logits, D_gen_logits=disc_out.D_gen_logits) + + loss = G_loss if mode == 'generator_loss' else D_loss + + return loss diff --git a/loss/perceptual_similarity/__init__.py b/loss/perceptual_similarity/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/loss/perceptual_similarity/base_model.py b/loss/perceptual_similarity/base_model.py new file mode 100644 index 0000000..73882c1 --- /dev/null +++ b/loss/perceptual_similarity/base_model.py @@ -0,0 +1,53 @@ +import os +import torch +from torch.autograd import Variable + +class BaseModel(): + def __init__(self): + pass + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = f'{epoch_label}_net_{network_label}' + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = f'{epoch_label}_net_{network_label}' + save_path = os.path.join(self.save_dir, save_filename) + print(f'Loading network from {save_path}') + network.load_state_dict(torch.load(save_path)) + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'),flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') + diff --git a/loss/perceptual_similarity/dist_model.py b/loss/perceptual_similarity/dist_model.py new file mode 100644 index 0000000..f728a40 --- /dev/null +++ b/loss/perceptual_similarity/dist_model.py @@ -0,0 +1,280 @@ + +from __future__ import absolute_import + +import sys +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +from .base_model import BaseModel +from scipy.ndimage import zoom +import fractions +import functools +import skimage.transform +from tqdm import tqdm + + +from . import networks_basic as networks +from . import perceptual_loss + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + ''' + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = '%s [%s]'%(model,net) + + if(self.model == 'net-lin'): # pretrained net + linear layer + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = {} + if not use_gpu: + kw['map_location'] = 'cpu' + if(model_path is None): + import inspect + model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) + + if(not is_train): + print('Loading model from: %s'%model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif(self.model=='net'): # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif(self.model in ['L2','l2']): + self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif(self.model in ['DSSIM','dssim','SSIM','ssim']): + self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." % self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if(use_gpu): + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if(self.is_train): + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if(printNet): + print('---------- Networks initialized -------------') + networks.print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if(hasattr(module, 'weight') and module.kernel_size==(1,1)): + module.weight.data = torch.clamp(module.weight.data,min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + if(self.use_gpu): + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref,requires_grad=True) + self.var_p0 = Variable(self.input_p0,requires_grad=True) + self.var_p1 = Variable(self.input_p1,requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) + + self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) + + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self,d0,d1,judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) + self.old_lr = lr + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() + d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() + gts+=data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) diff --git a/loss/perceptual_similarity/pretrained_networks.py b/loss/perceptual_similarity/pretrained_networks.py new file mode 100644 index 0000000..a70ebbe --- /dev/null +++ b/loss/perceptual_similarity/pretrained_networks.py @@ -0,0 +1,180 @@ +from collections import namedtuple +import torch +from torchvision import models as tv + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2,5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) + out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if(num==18): + self.net = tv.resnet18(pretrained=pretrained) + elif(num==34): + self.net = tv.resnet34(pretrained=pretrained) + elif(num==50): + self.net = tv.resnet50(pretrained=pretrained) + elif(num==101): + self.net = tv.resnet101(pretrained=pretrained) + elif(num==152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/loss/perceptual_similarity/weights/v0.1/alex.pth b/loss/perceptual_similarity/weights/v0.1/alex.pth new file mode 100644 index 0000000..1df9dfe Binary files /dev/null and b/loss/perceptual_similarity/weights/v0.1/alex.pth differ diff --git a/loss/perceptual_similarity/weights/v0.1/squeeze.pth b/loss/perceptual_similarity/weights/v0.1/squeeze.pth new file mode 100644 index 0000000..a3bd383 Binary files /dev/null and b/loss/perceptual_similarity/weights/v0.1/squeeze.pth differ diff --git a/loss/perceptual_similarity/weights/v0.1/vgg.pth b/loss/perceptual_similarity/weights/v0.1/vgg.pth new file mode 100644 index 0000000..47e943c Binary files /dev/null and b/loss/perceptual_similarity/weights/v0.1/vgg.pth differ diff --git a/network/AutoEncoder.py b/network/AutoEncoder.py new file mode 100644 index 0000000..b2de7ea --- /dev/null +++ b/network/AutoEncoder.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + + +if __name__ == "__main__": + from encoder import Encoder + from generator import Generator +else: + from .encoder import Encoder + from .generator import Generator + + +class AutoEncoder(nn.Module): + def __init__(self, image_dims, batch_size, C=20, activation='relu', n_residual_blocks=8, channel_norm=True): + super(AutoEncoder, self).__init__() + + self.encoder = Encoder(image_dims, batch_size, activation, C, channel_norm) + + # 计算encoder输出的维度 + with torch.no_grad(): + dummy_input = torch.zeros(1, *image_dims) + encoder_output = self.encoder(dummy_input) + encoder_output_dims = encoder_output.shape[1:] + + self.generator = Generator(encoder_output_dims, batch_size, C, activation, n_residual_blocks, channel_norm) + + def forward(self, x): + encoded = self.encoder(x) + decoded = self.generator(encoded) + return decoded + +# 测试代码 +if __name__ == "__main__": + batch_size = 32 + image_dims = (3, 96, 96) # (channels, height, width) + + model = AutoEncoder(image_dims, batch_size) + + # 创建一个随机输入张量来测试模型 + dummy_input = torch.randn(batch_size, *image_dims) + + output = model(dummy_input) + print(f"Input shape: {dummy_input.shape}") + print(f"Output shape: {output.shape}") + + # 确保输入和输出的形状相同 + assert dummy_input.shape == output.shape, "Input and output shapes do not match!" + print("AutoEncoder test passed successfully!") \ No newline at end of file diff --git a/network/__init__.py b/network/__init__.py new file mode 100644 index 0000000..d4d788a --- /dev/null +++ b/network/__init__.py @@ -0,0 +1 @@ +# Model / loss definitions diff --git a/network/channel.py b/network/channel.py new file mode 100644 index 0000000..98a8eee --- /dev/null +++ b/network/channel.py @@ -0,0 +1,59 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Parameter + +def InstanceNorm2D_wrap(input_channels, momentum=0.1, affine=True, + track_running_stats=False, **kwargs): + """ + Wrapper around default Torch instancenorm + """ + instance_norm_layer = nn.InstanceNorm2d(input_channels, + momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + return instance_norm_layer + +def ChannelNorm2D_wrap(input_channels, momentum=0.1, affine=True, + track_running_stats=False, **kwargs): + """ + Wrapper around Channel Norm module + """ + channel_norm_layer = ChannelNorm2D(input_channels, + momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + return channel_norm_layer + +class ChannelNorm2D(nn.Module): + """ + Similar to default Torch instanceNorm2D but calculates + moments over channel dimension instead of spatial dims. + Expects input_dim in format (B,C,H,W) + """ + + def __init__(self, input_channels, momentum=0.1, eps=1e-3, + affine=True, **kwargs): + super(ChannelNorm2D, self).__init__() + + self.momentum = momentum + self.eps = eps + self.affine = affine + + if affine is True: + self.gamma = nn.Parameter(torch.ones(1, input_channels, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, input_channels, 1, 1)) + + def forward(self, x): + """ + Calculate moments over channel dim, normalize. + x: Image tensor, shape (B,C,H,W) + """ + mu, var = torch.mean(x, dim=1, keepdim=True), torch.var(x, dim=1, keepdim=True) + + x_normed = (x - mu) * torch.rsqrt(var + self.eps) + + if self.affine is True: + x_normed = self.gamma * x_normed + self.beta + return x_normed diff --git a/network/discriminator.py b/network/discriminator.py new file mode 100644 index 0000000..6dd29c9 --- /dev/null +++ b/network/discriminator.py @@ -0,0 +1,95 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Discriminator(nn.Module): + def __init__(self, image_dims, context_dims, C, spectral_norm=True): + """ + Convolutional patchGAN discriminator used in [1]. + Accepts as input generator output G(z) or x ~ p*(x) where + p*(x) is the true data distribution. + Contextual information provided is encoder output y = E(x) + ======== + Arguments: + image_dims: Dimensions of input image, (C_in,H,W) + context_dims: Dimensions of contextual information, (C_in', H', W') + C: Bottleneck depth, controls bits-per-pixel + C = 220 used in [1], C = C_in' if encoder output used + as context. + + [1] Mentzer et. al., "High-Fidelity Generative Image Compression", + arXiv:2006.09965 (2020). + """ + super(Discriminator, self).__init__() + + self.image_dims = image_dims + self.context_dims = context_dims + im_channels = self.image_dims[0] + kernel_dim = 4 + context_C_out = 12 + filters = (64, 128, 256, 512) + + # Upscale encoder output - (C, 16, 16) -> (12, 256, 256) + self.context_conv = nn.Conv2d(C, context_C_out, kernel_size=3, padding=1, padding_mode='reflect') + self.context_upsample = nn.Upsample(scale_factor=16, mode='nearest') + + # Images downscaled to 500 x 1000 + randomly cropped to 256 x 256 + # assert image_dims == (im_channels, 256, 256), 'Crop image to 256 x 256!' + + # Layer / normalization options + # TODO: calculate padding properly + cnn_kwargs = dict(stride=2, padding=1, padding_mode='reflect') + self.activation = nn.LeakyReLU(negative_slope=0.2) + + if spectral_norm is True: + norm = nn.utils.spectral_norm + else: + norm = nn.utils.weight_norm + + # (C_in + C_in', 256,256) -> (64,128,128), with implicit padding + # TODO: Check if removing spectral norm in first layer works + self.conv1 = norm(nn.Conv2d(im_channels + context_C_out, filters[0], kernel_dim, **cnn_kwargs)) + + # (128,128) -> (64,64) + self.conv2 = norm(nn.Conv2d(filters[0], filters[1], kernel_dim, **cnn_kwargs)) + + # (64,64) -> (32,32) + self.conv3 = norm(nn.Conv2d(filters[1], filters[2], kernel_dim, **cnn_kwargs)) + + # (32,32) -> (16,16) + self.conv4 = norm(nn.Conv2d(filters[2], filters[3], kernel_dim, **cnn_kwargs)) + + self.conv_out = nn.Conv2d(filters[3], 1, kernel_size=1, stride=1) + + def forward(self, x, y): + """ + x: Concatenated real/gen images + y: Quantized latents + """ + batch_size = x.size()[0] + + # Concatenate upscaled encoder output y as contextual information + y = self.activation(self.context_conv(y)) + y = self.context_upsample(y) + + x = torch.cat((x,y), dim=1) + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.activation(self.conv3(x)) + x = self.activation(self.conv4(x)) + + out_logits = self.conv_out(x).view(-1,1) + out = torch.sigmoid(out_logits) + + return out, out_logits + +if __name__ == "__main__": + B = 2 + C = 7 + print('Image 1') + x = torch.randn((B,3,256,256)) + x_dims = tuple(x.size()) + D = Discriminator(image_dims=x_dims[1:], context_dims=tuple(x.size())[1:], C=C) + print('Discriminator output', x.size()) diff --git a/network/encoder.py b/network/encoder.py new file mode 100644 index 0000000..df526eb --- /dev/null +++ b/network/encoder.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +if __name__ == "__main__": + import channel, instance +else: + from . import channel, instance + + + +class Encoder(nn.Module): + def __init__(self, image_dims, batch_size, activation='relu', C=220, + channel_norm=True): + """ + Encoder with convolutional architecture proposed in [1]. + Projects image x ([C_in,256,256]) into a feature map of size C x W/16 x H/16 + ======== + Arguments: + image_dims: Dimensions of input image, (C_in,H,W) + batch_size: Number of instances per minibatch + C: Bottleneck depth, controls bits-per-pixel + C = {2,4,8,16} + + [1] Mentzer et. al., "High-Fidelity Generative Image Compression", + arXiv:2006.09965 (2020). + """ + + super(Encoder, self).__init__() + + kernel_dim = 3 + filters = (15, 30, 60, 120, 240) + + # Images downscaled to 500 x 1000 + randomly cropped to 256 x 256 + im_channels = image_dims[0] + # assert image_dims == (im_channels, 256, 256), 'Crop image to 256 x 256!' + + # Layer / normalization options + cnn_kwargs = dict(stride=2, padding=0, padding_mode='reflect') + norm_kwargs = dict(momentum=0.1, affine=True, track_running_stats=False) + activation_d = dict(relu='ReLU', elu='ELU', leaky_relu='LeakyReLU') + self.activation = getattr(nn, activation_d[activation]) # (leaky_relu, relu, elu) + self.n_downsampling_layers = 4 + + if channel_norm is True: + self.interlayer_norm = channel.ChannelNorm2D_wrap + else: + self.interlayer_norm = instance.InstanceNorm2D_wrap + + self.pre_pad = nn.ReflectionPad2d(3) + self.asymmetric_pad = nn.ReflectionPad2d((0,1,1,0)) # Slower than tensorflow? + self.post_pad = nn.ReflectionPad2d(1) + + heights = [2**i for i in range(4,9)][::-1] + widths = heights + H1, H2, H3, H4, H5 = heights + W1, W2, W3, W4, W5 = widths + + # (256,256) -> (256,256), with implicit padding + self.conv_block1 = nn.Sequential( + self.pre_pad, + nn.Conv2d(im_channels, filters[0], kernel_size=(7,7), stride=1), + self.interlayer_norm(filters[0], **norm_kwargs), + self.activation(), + ) + + # (256,256) -> (128,128) + self.conv_block2 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[0], filters[1], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[1], **norm_kwargs), + self.activation(), + ) + + # (128,128) -> (64,64) + self.conv_block3 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[1], filters[2], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[2], **norm_kwargs), + self.activation(), + ) + + # (64,64) -> (32,32) + self.conv_block4 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[2], filters[3], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[3], **norm_kwargs), + self.activation(), + ) + + # (32,32) -> (16,16) + self.conv_block5 = nn.Sequential( + self.asymmetric_pad, + nn.Conv2d(filters[3], filters[4], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[4], **norm_kwargs), + self.activation(), + ) + + # Project channels onto space w/ dimension C + # Feature maps have dimension C x W/16 x H/16 + # (16,16) -> (16,16) + self.conv_block_out = nn.Sequential( + self.post_pad, + nn.Conv2d(filters[4], C, kernel_dim, stride=1), + ) + + + def forward(self, x): + x = self.conv_block1(x) + x = self.conv_block2(x) + x = self.conv_block3(x) + x = self.conv_block4(x) + x = self.conv_block5(x) + out = self.conv_block_out(x) + return out + + +if __name__ == "__main__": + B = 32 + C = 220 + print('Image 1') + x = torch.randn((B,3,96,96)) + x_dims = tuple(x.size()) + E = Encoder(image_dims=x_dims[1:], batch_size=B, C=C) + print(E(x).size()) + diff --git a/network/generator.py b/network/generator.py new file mode 100644 index 0000000..ec56434 --- /dev/null +++ b/network/generator.py @@ -0,0 +1,195 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from torch.nn import GELU + + +if __name__ == "__main__": + import channel, instance +else: + from . import channel, instance + +class ResidualBlock(nn.Module): + def __init__(self, input_dims, kernel_size=3, stride=1, + channel_norm=True, activation='relu'): + """ + input_dims: Dimension of input tensor (B,C,H,W) + """ + super(ResidualBlock, self).__init__() + + self.activation = getattr(F, activation) + in_channels = input_dims[1] + norm_kwargs = dict(momentum=0.1, affine=True, track_running_stats=False) + + if channel_norm is True: + self.interlayer_norm = channel.ChannelNorm2D_wrap + else: + self.interlayer_norm = instance.InstanceNorm2D_wrap + + pad_size = int((kernel_size-1)/2) + self.pad = nn.ReflectionPad2d(pad_size) + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride) + self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride) + self.norm1 = self.interlayer_norm(in_channels, **norm_kwargs) + self.norm2 = self.interlayer_norm(in_channels, **norm_kwargs) + + def forward(self, x): + identity_map = x + res = self.pad(x) + res = self.conv1(res) + res = self.norm1(res) + res = self.activation(res) + + res = self.pad(res) + res = self.conv2(res) + res = self.norm2(res) + + return torch.add(res, identity_map) + +class Generator(nn.Module): + def __init__(self, input_dims, batch_size, C=16, activation='relu', + n_residual_blocks=8, channel_norm=True, sample_noise=False, + noise_dim=32): + + """ + Generator with convolutional architecture proposed in [1]. + Upscales quantized encoder output into feature map of size C x W x H. + Expects input size (C,16,16) + ======== + Arguments: + input_dims: Dimensions of quantized representation, (C,H,W) + batch_size: Number of instances per minibatch + C: Encoder bottleneck depth, controls bits-per-pixel + C = 220 used in [1]. + + [1] Mentzer et. al., "High-Fidelity Generative Image Compression", + arXiv:2006.09965 (2020). + """ + + super(Generator, self).__init__() + + kernel_dim = 3 + filters = [240, 120, 60,30,15] + self.n_residual_blocks = n_residual_blocks + self.sample_noise = sample_noise + self.noise_dim = noise_dim + + # Layer / normalization options + cnn_kwargs = dict(stride=2, padding=1, output_padding=1) + norm_kwargs = dict(momentum=0.1, affine=True, track_running_stats=False) + activation_d = dict(relu='ReLU', elu='ELU', leaky_relu='LeakyReLU', gelu='GELU') + self.activation = getattr(nn, activation_d[activation]) # (leaky_relu, relu, elu, gelu) + self.n_upsampling_layers = 4 + + if channel_norm is True: + self.interlayer_norm = channel.ChannelNorm2D_wrap + else: + self.interlayer_norm = instance.InstanceNorm2D_wrap + + self.pre_pad = nn.ReflectionPad2d(1) + self.asymmetric_pad = nn.ReflectionPad2d((0,1,1,0)) # Slower than tensorflow? + self.post_pad = nn.ReflectionPad2d(3) + + H0, W0 = input_dims[1:] + heights = [2**i for i in range(5,9)] + widths = heights + H1, H2, H3, H4 = heights + W1, W2, W3, W4 = widths + + + # (16,16) -> (16,16), with implicit padding + self.conv_block_init = nn.Sequential( + self.interlayer_norm(C, **norm_kwargs), + self.pre_pad, + nn.Conv2d(C, filters[0], kernel_size=(3,3), stride=1), + self.interlayer_norm(filters[0], **norm_kwargs), + ) + + if sample_noise is True: + # Concat noise with latent representation + filters[0] += self.noise_dim + + for m in range(n_residual_blocks): + resblock_m = ResidualBlock(input_dims=(batch_size, filters[0], H0, W0), + channel_norm=channel_norm, activation=activation) + self.add_module(f'resblock_{str(m)}', resblock_m) + + # (16,16) -> (32,32) + self.upconv_block1 = nn.Sequential( + nn.ConvTranspose2d(filters[0], filters[1], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[1], **norm_kwargs), + self.activation(), + ) + + self.upconv_block2 = nn.Sequential( + nn.ConvTranspose2d(filters[1], filters[2], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[2], **norm_kwargs), + self.activation(), + ) + + self.upconv_block3 = nn.Sequential( + nn.ConvTranspose2d(filters[2], filters[3], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[3], **norm_kwargs), + self.activation(), + ) + + self.upconv_block4 = nn.Sequential( + nn.ConvTranspose2d(filters[3], filters[4], kernel_dim, **cnn_kwargs), + self.interlayer_norm(filters[4], **norm_kwargs), + self.activation(), + ) + + self.conv_block_out = nn.Sequential( + self.post_pad, + nn.Conv2d(filters[-1], 3, kernel_size=(7,7), stride=1), + ) + + + def forward(self, x): + + head = self.conv_block_init(x) + + if self.sample_noise is True: + B, C, H, W = tuple(head.size()) + z = torch.randn((B, self.noise_dim, H, W)).to(head) + head = torch.cat((head,z), dim=1) + + for m in range(self.n_residual_blocks): + resblock_m = getattr(self, f'resblock_{str(m)}') + if m == 0: + x = resblock_m(head) + else: + x = resblock_m(x) + + x += head + x = self.upconv_block1(x) + x = self.upconv_block2(x) + x = self.upconv_block3(x) + x = self.upconv_block4(x) + out = self.conv_block_out(x) + + return out + + +if __name__ == "__main__": + import os + + C = 128 + y = torch.randn([10,C,2,2]) + y_dims = y.size() + G = Generator(y_dims[1:], y_dims[0], C=C, n_residual_blocks=3, sample_noise=True) + + x_hat = G(y) + print(f"Output size: {x_hat.size()}") + + # Save the model + torch.save(G.state_dict(), "generator_model.pth") + + # Get the file size + file_size = os.path.getsize("generator_model.pth") + print(f"Generator model size on disk: {file_size / 1024:.2f} KB") + + # Optionally, remove the file after checking its size + os.remove("generator_model.pth") \ No newline at end of file diff --git a/network/hyper.py b/network/hyper.py new file mode 100644 index 0000000..82f0851 --- /dev/null +++ b/network/hyper.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.helpers import maths +lower_bound_toward = maths.LowerBoundToward.apply + +def get_num_DLMM_channels(C, K=4, params=['mu','scale','mix']): + """ + C: Channels of latent representation (L3C uses 5). + K: Number of mixture coefficients. + """ + return C * K * len(params) + +def get_num_mixtures(K_agg, C, params=['mu','scale','mix']): + return K_agg // (len(params) * C) + +def unpack_likelihood_params(x, conv_out, log_scales_min): + + N, C, H, W = x.shape + K_agg = conv_out.shape[1] + + K = get_num_mixtures(K_agg, C) + + # For each channel: K pi / K mu / K sigma + conv_out = conv_out.reshape(N, 3, C, K, H, W) + logit_pis = conv_out[:, 0, ...] + means = conv_out[:, 1, ...] + log_scales = conv_out[:, 2, ...] + log_scales = lower_bound_toward(log_scales, log_scales_min) + x = x.reshape(N, C, 1, H, W) + + return x, (logit_pis, means, log_scales), K + + +class HyperpriorAnalysis(nn.Module): + """ + Hyperprior 'analysis model' as proposed in [1]. + + [1] Ballé et. al., "Variational image compression with a scale hyperprior", + arXiv:1802.01436 (2018). + + C: Number of input channels + """ + def __init__(self, C=220, N=320, activation='relu'): + super(HyperpriorAnalysis, self).__init__() + + cnn_kwargs = dict(kernel_size=5, stride=2, padding=2, padding_mode='reflect') + self.activation = getattr(F, activation) + self.n_downsampling_layers = 2 + + self.conv1 = nn.Conv2d(C, N, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(N, N, **cnn_kwargs) + self.conv3 = nn.Conv2d(N, N, **cnn_kwargs) + + def forward(self, x): + + # x = torch.abs(x) + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.conv3(x) + + return x + + +class HyperpriorSynthesis(nn.Module): + """ + Hyperprior 'synthesis model' as proposed in [1]. Outputs + distribution parameters of input latents. + + [1] Ballé et. al., "Variational image compression with a scale hyperprior", + arXiv:1802.01436 (2018). + + C: Number of output channels + """ + def __init__(self, C=220, N=320, activation='relu', final_activation=None): + super(HyperpriorSynthesis, self).__init__() + + cnn_kwargs = dict(kernel_size=5, stride=2, padding=2, output_padding=1) + self.activation = getattr(F, activation) + self.final_activation = final_activation + + self.conv1 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv2 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv3 = nn.ConvTranspose2d(N, C, kernel_size=3, stride=1, padding=1) + + if self.final_activation is not None: + self.final_activation = getattr(F, final_activation) + + def forward(self, x): + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.conv3(x) + + if self.final_activation is not None: + x = self.final_activation(x) + return x + + +class HyperpriorSynthesisDLMM(nn.Module): + """ + Outputs distribution parameters of input latents, conditional on + hyperlatents, assuming a discrete logistic mixture model. + + C: Number of output channels + """ + def __init__(self, C=64, N=320, activation='relu', final_activation=None): + super(HyperpriorSynthesisDLMM, self).__init__() + + cnn_kwargs = dict(kernel_size=5, stride=2, padding=2, output_padding=1) + self.activation = getattr(F, activation) + self.final_activation = final_activation + + self.conv1 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv2 = nn.ConvTranspose2d(N, N, **cnn_kwargs) + self.conv3 = nn.ConvTranspose2d(N, C, kernel_size=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(C, get_num_DLMM_channels(C), kernel_size=1, stride=1) + + if self.final_activation is not None: + self.final_activation = getattr(F, final_activation) + + def forward(self, x): + x = self.activation(self.conv1(x)) + x = self.activation(self.conv2(x)) + x = self.conv3(x) + x = self.conv_out(x) + + if self.final_activation is not None: + x = self.final_activation(x) + return x diff --git a/network/instance.py b/network/instance.py new file mode 100644 index 0000000..37595c9 --- /dev/null +++ b/network/instance.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Parameter + +def InstanceNorm2D_wrap(input_channels, momentum=0.1, affine=True, + track_running_stats=False, **kwargs): + """ + Wrapper around default Torch instancenorm + """ + instance_norm_layer = nn.InstanceNorm2d(input_channels, + momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + return instance_norm_layer