s47036219 ADNI VQ-VAE Connor Armstrong #161

Open · wants to merge 20 commits into base: topic-recognition
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@

*.pth
*.pyc
70 changes: 70 additions & 0 deletions recognition/vq-vae_s47036219/README.md
@@ -0,0 +1,70 @@
# VQ-VAE for the ADNI Dataset

**Author**: Connor Armstrong (s47036219)


# Project:

## The Vector Quantized Variational Autoencoder
The goal of this task was to implement a Vector Quantized Variational Autoencoder (henceforth referred to as a VQ-VAE). The VQ-VAE extends the typical variational autoencoder with discrete latent representation learning: the model learns to represent data with latent variables that take on distinct discrete values rather than a continuous range. It does this by passing the encoder's output through a vector quantization layer, which maps each continuous encoding to the closest vector in an embedding space (the codebook). This makes the VQ-VAE very effective at modelling discretely structured data and at image reconstruction and generation.


## VQ-VAE Architecture
![VQ-VAE Structure](./vqvae_structure.jpg)

As shown above, the VQ-VAE is composed of a few important components:

- **Encoder**:
The encoder takes an input `x` and compresses it into a continuous latent representation `Z_e(x)`.

- **Latent Predictor p(z)**:
This is not necessarily an actual module; in most VQ-VAE architectures it is not explicitly present. However, it is useful to think of the latent space as having some underlying prior distribution `p(z)` that the model tries to capture or mimic.

- **Nearest Neighbors & Codebook**:
One of the most important features of the VQ-VAE is its use of a discrete codebook, where each entry is a vector. The continuous output of the encoder (`Z_e(x)`) is mapped to the nearest vector in this codebook, represented by the table at the bottom of the figure, where each row is a unique codebook vector. Mapping `Z_e(x)` to its nearest codebook vector yields `Z_q(x)`, a quantized version of the encoder's output.

- **Decoder**:
The decoder takes the quantized latent representation `Z_q(x)` and reconstructs the original input, producing `x'`. Ideally, `x'` should be a close approximation of the original input `x`.

The use of a discrete codebook in the latent space (instead of a continuous one) allows the VQ-VAE to capture more complex data distributions with fewer latent variables.
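To make the lookup concrete, here is a minimal PyTorch sketch of the nearest-neighbour codebook assignment described above (the tensor names, codebook size, and spatial resolution are illustrative, not the project's actual settings):

```python
import torch

# Toy codebook: 512 embeddings, each of dimension 64
codebook = torch.randn(512, 64)

# Encoder output for one image: 64 channels on a 16x16 grid
z_e = torch.randn(1, 64, 16, 16)

# Flatten to one row per spatial position: [1*16*16, 64]
flat = z_e.permute(0, 2, 3, 1).reshape(-1, 64)

# Squared Euclidean distance from every position to every codebook vector
dists = torch.cdist(flat, codebook) ** 2

# Pick the closest codebook entry per position and look it up
indices = dists.argmin(dim=1)
z_q = codebook[indices].view(1, 16, 16, 64).permute(0, 3, 1, 2)

print(z_q.shape)  # torch.Size([1, 64, 16, 16])
```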



## VQ-VAE and the ADNI Dataset
The ADNI (Alzheimer’s Disease Neuroimaging Initiative) dataset is a collection of neuroimaging data, curated with the primary intent of studying Alzheimer's disease. In the context of the ADNI dataset, a VQ-VAE can be applied to condense complex brain scans into a more manageable, lower-dimensional, discrete latent space. By doing so, it can effectively capture meaningful patterns and structures inherent in the images.


## Details on the implementation

The goal of this project was to: "Create a generative model of the ADNI brain dataset using a VQ-VAE that has a “reasonably clear image” and a Structural Similarity (SSIM) of over 0.6."
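For reference, the Structural Similarity index between two images $x$ and $y$ is commonly defined as

$$\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}$$

where $\mu$, $\sigma^2$ and $\sigma_{xy}$ denote the image means, variances and covariance, and $C_1$, $C_2$ are small stabilising constants. The `ssim` helper in `modules.py` evaluates a simplified version of this from global per-image statistics rather than a sliding window.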

This implementation is relatively standard for this model. Other extensions could be of great use here; combining the VQ-VAE with a GAN or another generative model would be a powerful way to improve upon my implementation, but this is left for other students with more time.

# Usage:
**Please Note: Before running, set the paths to the training and test folders for the dataset in `train.py`.**

It is highly recommended to run only the `predict.py` file by calling `python predict.py` from the working directory. It is possible to run from `train.py` as well, but this has data-leakage implications, as I could not find a proper way to partition the test set.

If all goes well, matplotlib outputs four images: the original and reconstructed brain with the highest SSIM, followed by the pair with the lowest SSIM.

# Data
This project uses the ADNI dataset (in the structure provided on Blackboard): the training set is used to train the model, and the test folder is partitioned into a validation set and a test set.


# Dependencies
| Dependency | Version |
|-------------|-------------|
| torch | 2.0.1+cu117 |
| torchvision | 0.15.2+cu117|
| matplotlib | 3.8.0 |

# Output
As stated earlier, these are the images with the highest and lowest SSIM scores:
![Output Image](./output.png)

# References
The following sources inspired my implementation and were referenced in order to complete this project:
* Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu. Neural Discrete Representation Learning, 2017. https://arxiv.org/abs/1711.00937
* ADNI Brain Dataset, thanks to the Alzheimer's Disease Neuroimaging Initiative. https://adni.loni.usc.edu/
* Misha Laskin, vqvae. https://github.com/MishaLaskin/vqvae/tree/master
* Aurko Roy et al. Theory and Experiments on Vector Quantized Autoencoders, 2018. https://www.arxiv-vanity.com/papers/1805.11063/
23 changes: 23 additions & 0 deletions recognition/vq-vae_s47036219/dataset.py
@@ -0,0 +1,23 @@
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

def get_dataloaders(train_string, test_validation_string, batch_size):
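    # Build train / validation / test DataLoaders from ImageFolder directories.
    # Images are converted to single-channel grayscale and resized to 64x64;
    # the test folder is split 30% test / 70% validation further below.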
transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Resize((64, 64)),
transforms.ToTensor(),
#transforms.Normalize(mean=[0.5], std=[0.5])
])
train_dataset = datasets.ImageFolder(root=train_string, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

full_test_dataset = datasets.ImageFolder(root=test_validation_string, transform=transform)
test_size = int(0.3 * len(full_test_dataset))
val_size = len(full_test_dataset) - test_size

test_dataset, val_dataset = random_split(full_test_dataset, [test_size, val_size])

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

return train_loader, val_loader, test_loader
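A minimal sketch of how `get_dataloaders` might be called (the folder paths and batch size are placeholders, not the project's actual configuration):

```python
from dataset import get_dataloaders

# Placeholder paths to the ADNI train and test image folders
train_loader, val_loader, test_loader = get_dataloaders(
    "path/to/ADNI/train", "path/to/ADNI/test", batch_size=32
)

# Each batch is a tensor of grayscale 64x64 images plus class labels
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 1, 64, 64])
```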
132 changes: 132 additions & 0 deletions recognition/vq-vae_s47036219/modules.py
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn




class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, intermediate_channels=None):
super(ResidualBlock, self).__init__()

if not intermediate_channels:
intermediate_channels = in_channels // 2

self._residual_block = nn.Sequential(
nn.ReLU(),
nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False),
nn.ReLU(),
nn.Conv2d(intermediate_channels, out_channels, kernel_size=1, stride=1, bias=False)
)

def forward(self, x):
return x + self._residual_block(x)


class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()

self.layers = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Dropout(0.5),

nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Dropout(0.5),

ResidualBlock(64, 64),
ResidualBlock(64, 64)
)

def forward(self, x):
out = self.layers(x)
return out

class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super(VectorQuantizer, self).__init__()

self.num_embeddings = num_embeddings # Save as an instance variable
self.embedding = nn.Embedding(self.num_embeddings, embedding_dim)
self.embedding.weight.data.uniform_(-1./self.num_embeddings, 1./self.num_embeddings)

def forward(self, x):
batch_size, channels, height, width = x.shape
x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, channels)

# Now x_flat is [batch_size * height * width, channels]

# Calculate distances
distances = ((x_flat.unsqueeze(1) - self.embedding.weight.unsqueeze(0)) ** 2).sum(-1)

# Find the closest embeddings
_, indices = distances.min(1)
encodings = torch.zeros_like(distances).scatter_(1, indices.unsqueeze(1), 1)

        # Quantize the input by looking up the selected codebook vectors
        quantized = self.embedding(indices)

        # Reshape the quantized tensor to the same shape as the input
        quantized = quantized.view(batch_size, height, width, channels).permute(0, 3, 1, 2)

        # Straight-through estimator: copy gradients from the quantized output back
        # to the encoder output, since the nearest-neighbour lookup is not differentiable
        quantized = x + (quantized - x).detach()

        return quantized

class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()

self.layers = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Dropout(0.5),

ResidualBlock(64, 64),
ResidualBlock(64, 64),

nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Dropout(0.5),

nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)
)

def forward(self, x):
return self.layers(x)

class VQVAE(nn.Module):
def __init__(self, num_embeddings=512, embedding_dim=64):
super(VQVAE, self).__init__()

self.encoder = Encoder()
self.conv1 = nn.Conv2d(64, embedding_dim, kernel_size=1, stride=1)
self.vector_quantizer = VectorQuantizer(num_embeddings, embedding_dim)
self.decoder = Decoder()

def forward(self, x):
enc = self.encoder(x)
enc = self.conv1(enc)
quantized = self.vector_quantizer(enc)

dec = self.decoder(quantized)
return dec


def ssim(img1, img2, C1=0.01**2, C2=0.03**2):
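    # Simplified SSIM: uses global per-image statistics (means, variances, covariance)
    # rather than a sliding Gaussian window, and returns the mean over the batch.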
mu1 = img1.mean(dim=[2, 3], keepdim=True)
mu2 = img2.mean(dim=[2, 3], keepdim=True)

sigma1_sq = (img1 - mu1).pow(2).mean(dim=[2, 3], keepdim=True)
sigma2_sq = (img2 - mu2).pow(2).mean(dim=[2, 3], keepdim=True)
sigma12 = ((img1 - mu1)*(img2 - mu2)).mean(dim=[2, 3], keepdim=True)

ssim_n = (2*mu1*mu2 + C1) * (2*sigma12 + C2)
ssim_d = (mu1.pow(2) + mu2.pow(2) + C1) * (sigma1_sq + sigma2_sq + C2)

ssim_val = ssim_n / ssim_d

return ssim_val.mean()
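As a quick illustration of how `VQVAE` and `ssim` combine into the reconstruction objective used elsewhere in this submission, here is a short sketch (the loss weights are illustrative placeholders, not the values defined in `train.py`):

```python
import torch
from modules import VQVAE, ssim

model = VQVAE()
images = torch.randn(8, 1, 64, 64)  # dummy batch of grayscale 64x64 images

recon = model(images)

# Weighted combination of pixel-wise L2 loss and (1 - SSIM), mirroring predict.py
l2_loss = ((recon - images) ** 2).sum()
ssim_loss = 1 - ssim(images, recon)
loss = 1e-4 * l2_loss + 1.0 * ssim_loss

loss.backward()
```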
Binary file added recognition/vq-vae_s47036219/output.png
102 changes: 102 additions & 0 deletions recognition/vq-vae_s47036219/predict.py
@@ -0,0 +1,102 @@
import torch
from modules import VQVAE, ssim
from dataset import get_dataloaders
from train import SSIM_WEIGHT, L2_WEIGHT, BATCH_SIZE, train_new_model, path_to_training_folder, path_to_test_folder
import matplotlib
import matplotlib.pyplot as plt
import os

def evaluate(test_loader):
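    # Run the trained model over the test loader, tracking the mean SSIM and the
    # batches with the highest and lowest reconstruction SSIM for plotting.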
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQVAE().to(device)
model.load_state_dict(torch.load('vqvae.pth'))
model.eval()
print("loaded")
highest_ssim_val = float('-inf') # Initialize with negative infinity
lowest_ssim_val = float('inf') # Initialize with positive infinity
highest_ssim_img = None
highest_ssim_recon = None
lowest_ssim_img = None
lowest_ssim_recon = None

val_losses = []
ssim_sum = 0 # To keep track of sum of all SSIM values
total_images = 0 # To keep track of total number of images processed

with torch.no_grad():
for i, (img, _) in enumerate(test_loader):
img = img.to(device)

# Validation forward pass
z = model.encoder(img)
z = model.conv1(z)
z_q = model.vector_quantizer(z)
recon = model.decoder(z_q)

# Validation losses
l2_loss = ((recon - img) ** 2).sum()
ssim_loss = 1 - ssim(img, recon)
loss = L2_WEIGHT * l2_loss + SSIM_WEIGHT * ssim_loss
val_losses.append(loss.item())

# Calculate SSIM
ssim_val = ssim(img, recon).item()
            ssim_sum += ssim_val * img.size(0)  # weight the per-batch mean SSIM by batch size so the overall mean is per-image
total_images += img.size(0) # Increase the total number of images processed

#print(f'SSIM: {ssim_val}') # Output SSIM value

# Update highest and lowest SSIM values and corresponding images
if ssim_val > highest_ssim_val:
highest_ssim_val = ssim_val
highest_ssim_img = img.cpu().numpy().squeeze(1)
highest_ssim_recon = recon.cpu().numpy().squeeze(1)

if ssim_val < lowest_ssim_val:
lowest_ssim_val = ssim_val
lowest_ssim_img = img.cpu().numpy().squeeze(1)
lowest_ssim_recon = recon.cpu().numpy().squeeze(1)

mean_ssim = ssim_sum / total_images
print(f'Mean SSIM: {mean_ssim}') # Output mean SSIM value

# Output images with the highest and lowest SSIM values
plt.figure(figsize=(10, 5))

plt.subplot(2, 2, 1)
plt.title(f'Original Highest SSIM: {highest_ssim_val}')
plt.imshow(highest_ssim_img[0], cmap='gray')

plt.subplot(2, 2, 2)
plt.title('Reconstructed')
plt.imshow(highest_ssim_recon[0], cmap='gray')

plt.subplot(2, 2, 3)
plt.title(f'Original Lowest SSIM: {lowest_ssim_val}')
plt.imshow(lowest_ssim_img[0], cmap='gray')

plt.subplot(2, 2, 4)
plt.title('Reconstructed')
plt.imshow(lowest_ssim_recon[0], cmap='gray')

plt.tight_layout()
plt.show()

def main():
weight_file_path = "vqvae.pth"

    train, validate, test = get_dataloaders(path_to_training_folder, path_to_test_folder, BATCH_SIZE)

if os.path.exists(weight_file_path):
print("Weights exist -> Evaluating Model...")
evaluate(test)

else:
print(f"Weight file {weight_file_path} does not exist.")
print("Training model now...")
train_new_model(train, validate)



if __name__ == "__main__":
main()