VIT - S46878467 #176

Open · wants to merge 17 commits into main

76 changes: 65 additions & 11 deletions README.md
# Classifying Alzheimer's Disease Diagnoses Using a Vision Transformer

This project classifies the ADNI brain dataset into AD (Alzheimer's Disease) and NC (Normal Cognition) groups. It uses a Vision Transformer (ViT) based on the principles of the original ViT paper. The model was trained with the Adam optimizer, and the hyperparameters were tuned across runs to find a good accuracy. Each sample corresponds to one patient and consists of 20 greyscale slices of 240x256 pixels, and each patient is classified as either NC or AD.

## Dataset Splitting

The dataset comes pre-split into 21,500 images for training and 9,000 images for testing. A third split was still needed for validation, which is created in the dataset.py file:

```python
data_val, data_test = random_split(TensorDataset(xtest, ytest), [0.7, 0.3])
```

Using `random_split` with a 70/30 split gives 6,300 images for validation and 2,700 for testing.
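Note that `random_split` shuffles randomly, so the exact validation/test membership changes between runs. A seeded variant (a minimal sketch; the seed value is illustrative, and `xtest`/`ytest` are the test tensors built in dataset.py) makes the split reproducible:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Seeding the generator fixes the 70/30 split across runs (seed is illustrative)
generator = torch.Generator().manual_seed(42)
data_val, data_test = random_split(TensorDataset(xtest, ytest), [0.7, 0.3],
                                   generator=generator)
```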

## Preprocessing the data
The code prepares each 20-slice volume for the Vision Transformer by dividing it into patches, linearly projecting each patch, prepending a classification token, and adding positional embeddings; the resulting token sequence is then processed by Transformer blocks that apply layer normalization and multi-head attention.
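As a quick shape check (a sketch using `image_patcher` from modules.py), one batch of patient volumes becomes a sequence of flattened patch vectors:

```python
import torch
from modules import image_patcher

batch = torch.randn(2, 20, 1, 240, 256)  # 2 patients, 20 greyscale 240x256 slices
patches = image_patcher(batch, size_patch=16, patch_depth=5)
# (20//5) * (240//16) * (256//16) = 4 * 15 * 16 = 960 patches per patient,
# each flattened to 1 * 5 * 16 * 16 = 1280 values
print(patches.shape)  # torch.Size([2, 960, 1280])
```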

## Training the model
The following parameters were used for training. No parameter for the number of channels was needed, since the images are single-channel greyscale.
```python
vit = VisionTransformer(input_dimen=128,
                        hiddenlayer_dimen=256,
                        number_heads=4,
                        transform_layers=4,
                        predict_num=2,
                        size_patch=(16,16))
```
* `input_dimen` - dimensionality of the input feature vectors to the Transformer
* `hiddenlayer_dimen` - dimensionality of the hidden layer in the feed-forward networks within the Transformer
* `number_heads` - number of heads to use in the multi-head attention blocks
* `transform_layers` - number of layers to use in the Transformer
* `predict_num` - number of classes to predict
* `size_patch` - number of pixels that the patches have per dimension
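With these parameters, each patient volume becomes 960 patch tokens plus one classification token. A quick smoke test (a sketch, assuming modules.py is on the path) confirms the model emits one AD/NC logit pair per patient:

```python
import torch
from modules import VisionTransformer

vit = VisionTransformer(input_dimen=128, hiddenlayer_dimen=256,
                        number_heads=4, transform_layers=4,
                        predict_num=2, size_patch=(16, 16))
batch = torch.randn(2, 20, 1, 240, 256)  # 2 patients, 20 slices each
print(vit(batch).shape)  # torch.Size([2, 2])
```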

The time taken to finish training depended on the parameters:
* Adam optimizer, learning rate 1e-4, 75 epochs: accuracy 0.68 (about 5.5 hours)
* AdamW optimizer, learning rate 3e-4, 100 epochs: accuracy 0.53 (about 7 hours)

## Configuration
All main configuration is done in the train.py file. The train function contains:

```python
optimizer = optim.AdamW(net.parameters(), lr=3e-4)
epochs = 100
```

The optimizer, learning rate and number of epochs can all be changed here.
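For example, to reproduce the better-performing run above, the same two lines can be edited by hand (a sketch; these are the settings that gave 0.68 accuracy):

```python
optimizer = optim.Adam(net.parameters(), lr=1e-4)
epochs = 75
```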
The ViT itself is instantiated at the end of train.py:

```python
vit = VisionTransformer(input_dimen=128,
                        hiddenlayer_dimen=256,
                        number_heads=4,
                        transform_layers=4,
                        predict_num=2,
                        size_patch=(16,16))
```

## Results
These are the results:

Loss vs epoch:

![image](https://github.com/HaadiQureshi/VIT-46878467/assets/141606798/64605a94-429c-4dc8-b5fd-8e4e10276942)

Accuracy vs epoch:

<img width="596" alt="image" src="https://github.com/HaadiQureshi/VIT-46878467/assets/141606798/4e6fa71b-ec70-482b-bc81-2cf51e819b15">



## How to use
The project consists of four files: dataset.py, modules.py, train.py and predict.py. The scripts to run are train.py and predict.py. train.py handles training and testing of the model, offers the option to save the trained model, and records the loss and validation accuracy for each epoch. predict.py then uses this data to plot the loss and accuracy curves with matplotlib.
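Saving and restoring the trained weights can be done with standard PyTorch serialization (a sketch; the filename `vit_adni.pt` is illustrative):

```python
import torch
from modules import VisionTransformer

# After training in train.py:
torch.save(vit.state_dict(), "vit_adni.pt")

# Later, e.g. at the top of predict.py:
model = VisionTransformer(input_dimen=128, hiddenlayer_dimen=256,
                          number_heads=4, transform_layers=4,
                          predict_num=2, size_patch=(16, 16))
model.load_state_dict(torch.load("vit_adni.pt"))
model.eval()
```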



Key considerations:
1. dataset.py loads, preprocesses and organizes the medical image data from the ADNI directories: it converts the images to tensors, groups them into training and testing sets with corresponding labels, and creates data loaders for training, testing and validation.
2. train.py imports the required libraries, modules and functions, loads the data via returnDataLoaders from dataset.py, defines lists for storing losses and accuracies, and sets up a training function that uses the AdamW optimizer and CrossEntropyLoss.
3. predict.py plots two separate graphs: accuracy vs epoch, showing the trend of the model's validation accuracy over training, and loss vs epoch, showing how the training loss varies throughout training.
4. modules.py contains the functions and classes implementing the Vision Transformer: an image-patching function, an attention block class for multi-head attention, and a VisionTransformer class that applies the linear projection, positional embeddings and classification head.
## URL
https://github.com/HaadiQureshi/VIT-46878467.git
89 changes: 89 additions & 0 deletions dataset.py
import numpy as np
import torch
from PIL import Image
import os
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, random_split

transform = transforms.Compose([
    transforms.PILToTensor()
])

xtrain = []
xtest = []
slicemax = 20  # 20 slices per patient

def load_patients(directory, destination):
    # Group every `slicemax` consecutive (sorted) slices into one patient
    # volume and append it to `destination`; returns the number of patients.
    n_patients = 0
    patient = []
    for filename in sorted(os.listdir(directory)):
        img = Image.open(os.path.join(directory, filename))
        patient.append(transform(img).float() / 255)  # go from 0,255 to 0,1
        if len(patient) == slicemax:
            destination.append(torch.stack(patient))
            patient = []
            n_patients += 1
    return n_patients

ntrainimgs_AD = load_patients('../ADNI_AD_NC_2D/AD_NC/train/AD/', xtrain)
ntrainimgs_NC = load_patients('../ADNI_AD_NC_2D/AD_NC/train/NC', xtrain)
ntestimgs_AD = load_patients('../ADNI_AD_NC_2D/AD_NC/test/AD', xtest)
ntestimgs_NC = load_patients('../ADNI_AD_NC_2D/AD_NC/test/NC', xtest)

xtrain = torch.stack(xtrain)
xtest = torch.stack(xtest)
# Labels: AD = 1, NC = 0 (AD patients were appended first)
ytrain = torch.from_numpy(np.concatenate((np.ones(ntrainimgs_AD), np.zeros(ntrainimgs_NC)), axis=0)).type(torch.LongTensor)
ytest = torch.from_numpy(np.concatenate((np.ones(ntestimgs_AD), np.zeros(ntestimgs_NC)), axis=0)).type(torch.LongTensor)

data_val, data_test = random_split(TensorDataset(xtest, ytest), [0.7,0.3])
dataloader_train = DataLoader(TensorDataset(xtrain, ytrain), batch_size=32, shuffle=True)
dataloader_test = DataLoader(data_test, batch_size=32, shuffle=True)
dataloader_val = DataLoader(data_val, batch_size=32, shuffle=True)

def returnDataLoaders():
    return dataloader_train, dataloader_test, dataloader_val
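Everything above loads eagerly, so the full training set (1,075 patients of 20 x 240 x 256 floats, roughly 5 GB) is held in RAM; a lazy `Dataset` would reduce the footprint. A minimal usage sketch of the exported loaders:

```python
from dataset import returnDataLoaders

dataloader_train, dataloader_test, dataloader_val = returnDataLoaders()
x_batch, y_batch = next(iter(dataloader_train))
print(x_batch.shape, y_batch.shape)  # torch.Size([32, 20, 1, 240, 256]) torch.Size([32])
```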
73 changes: 73 additions & 0 deletions modules.py
import torch
import torch.nn as nn

def image_patcher(image, size_patch, patch_depth):
    # image: (Batch, Depth, Channels, Height, Width), e.g. (B, 20, 1, 240, 256)
    Batch_Size, Depth, C, Height, Width = image.shape
    # number of patches along each axis
    Height_final = Height // size_patch
    Width_final = Width // size_patch
    Depth_final = Depth // patch_depth
    # split each axis into (number of patches, patch size)
    image = image.reshape(Batch_Size, Depth_final, patch_depth, C, Height_final, size_patch, Width_final, size_patch)
    # group the patch-grid axes: (B, Depth_final, Height_final, Width_final, C, patch_depth, size_patch, size_patch)
    image = image.permute(0, 1, 4, 6, 3, 2, 5, 7)
    # flatten to (B, number of patches, C * patch_depth * size_patch * size_patch)
    image = image.flatten(1, 3).flatten(2, 5)
    return image

class AttentionBlock(nn.Module):
    def __init__(self, input_dimen, hiddenlayer_dimen, number_heads):
        super().__init__()
        # layer normalization applied to the input of the attention block
        self.input_layer_norm = nn.LayerNorm(input_dimen)
        # layer normalization applied before the feed-forward network
        self.output_layer_norm = nn.LayerNorm(input_dimen)
        # block with multiple attention heads
        self.multihead_attention = nn.MultiheadAttention(input_dimen, number_heads)
        self.linear = nn.Sequential(nn.Linear(input_dimen, hiddenlayer_dimen),
                                    nn.GELU(),
                                    nn.Linear(hiddenlayer_dimen, input_dimen))

    def forward(self, image):
        # pre-norm residual attention, followed by a pre-norm residual MLP
        inp_x = self.input_layer_norm(image)
        image = image + self.multihead_attention(inp_x, inp_x, inp_x)[0]
        image = image + self.linear(self.output_layer_norm(image))
        return image

class VisionTransformer(nn.Module):
    def __init__(
        self, input_dimen, hiddenlayer_dimen, number_heads, transform_layers, predict_num, size_patch
    ):
        super().__init__()
        (size_patch_x, size_patch_y) = size_patch

        self.size_patch = size_patch_x * size_patch_y
        # linear projection of each flattened patch; the factor 5 is the
        # hard-coded patch depth used in forward()
        self.input_layer = nn.Linear(5 * self.size_patch, input_dimen)
        # stack of Transformer encoder blocks
        self.final_transform = nn.Sequential(*(AttentionBlock(input_dimen, hiddenlayer_dimen, number_heads) for _ in range(transform_layers)))
        # classification head applied to the CLS token
        self.dense_head = nn.Sequential(nn.LayerNorm(input_dimen), nn.Linear(input_dimen, predict_num))
        # positional embeddings, sized for the hard-coded 240x256 ADNI slices
        final_num_patch = 1 + (240 // size_patch_x) * (256 // size_patch_y)
        self.positional_emb = nn.Parameter(torch.randn(1, 4 * final_num_patch, input_dimen))
        self.classification_tkn = nn.Parameter(torch.randn(1, 1, input_dimen))

    def forward(self, image):
        # split each volume into patches; 16 and 5 match size_patch and the
        # hard-coded patch depth assumed in __init__
        image = image_patcher(image, 16, 5)
        Batch_Size, x, _ = image.shape

        image = self.input_layer(image)

        # prepend a CLS token and add the positional encoding
        classification_tkn = self.classification_tkn.repeat(Batch_Size, 1, 1)
        image = torch.cat([classification_tkn, image], dim=1)
        image = image + self.positional_emb[:, : x + 1]

        # nn.MultiheadAttention expects (sequence, batch, features)
        image = image.transpose(0, 1)
        image = self.final_transform(image)
        class_ = image[0]  # output at the CLS token position
        out = self.dense_head(class_)
        return out
29 changes: 29 additions & 0 deletions predict.py
import matplotlib.pyplot as plt

# Validation accuracies for the first 100 epochs, copied from a training run
accuracies = [0.4962, 0.5274, 0.5917, 0.4944, 0.5124, 0.5476, 0.5922, 0.5431, 0.5013, 0.5294, 0.5464, 0.5487, 0.5922, 0.6340, 0.5876, 0.6005, 0.6104, 0.6167, 0.6243, 0.6215, 0.6473, 0.6436, 0.5936, 0.6391, 0.6030, 0.6030, 0.5675, 0.6116, 0.6212, 0.5968, 0.5843, 0.6056, 0.6073, 0.6175, 0.5985, 0.5948, 0.6186, 0.5752, 0.6056, 0.5769, 0.6042, 0.6212, 0.5825, 0.5931, 0.5786, 0.6002, 0.5712, 0.5624, 0.5907, 0.5848, 0.6260, 0.6030, 0.5854, 0.5819, 0.6329, 0.6042, 0.6204, 0.6106, 0.6110, 0.6112, 0.6126, 0.6126, 0.6135, 0.6098, 0.6081, 0.6087, 0.6130, 0.6167, 0.6118, 0.6130, 0.6135, 0.6124, 0.6118, 0.6135, 0.6124, 0.6106, 0.6141, 0.6147, 0.6141, 0.6112, 0.6118, 0.6130, 0.6153, 0.6135, 0.6130, 0.6101, 0.6124, 0.6135, 0.6141, 0.6095, 0.6147, 0.6112, 0.6118, 0.6093, 0.6159, 0.6098, 0.6087, 0.6093, 0.6093, 0.6093]
# Training loss for the first 100 epochs, copied from the same run
loss = [0.6979881910716786, 0.6844412877279169, 0.6904780426446129, 0.684895454084172, 0.6812622880234438, 0.6792347308467416, 0.6771022852729348, 0.6855413440395804, 0.6720500413109275, 0.6870778799057007, 0.6519009260570302, 0.6404731641797459, 0.6227208665188622, 0.5900690169895396, 0.5951909887440064, 0.5607741247205174, 0.5638843687141643, 0.5399825283709694, 0.5193865106386297, 0.4769795151317821, 0.46142111543346853, 0.45512034174273996, 0.4230612568995532, 0.3895933969932444, 0.39457252358689027, 0.3760155246538274, 0.37903575160924124, 0.3169827342909925, 0.3307492321028429, 0.25006428689641114, 0.2171676978468895, 0.284952069468358, 0.19351636508808417, 0.19536656217978282, 0.1634677432696609, 0.11777753673274727, 0.14037292516406844, 0.23939100105096311, 0.11117816716432571, 0.06792027106070343, 0.11127064086715965, 0.08680430388845065, 0.08349954084876705, 0.0602918595815187, 0.04537964773172622, 0.018756276154068902, 0.017304050311555758, 0.0667554905932561, 0.056819359415813404, 0.01601847937426475, 0.0256724186885335, 0.08536267633248559, 0.016678674338275894, 0.021344836472588426, 0.03200960334951935, 0.054271318350562495, 0.032041940687443406, 0.008468315467539737, 0.0025479644400012843, 0.0009727750567596077, 0.0007114587334559902, 0.0006080651262035483, 0.0005404480839120772, 0.0004825884388992563, 0.00043984228036339013, 0.0004121263824773076, 0.00037888841025586077, 0.0003548304273007328, 0.0003334864805390894, 0.0003150697696529438, 0.0002981368049992906, 0.00028354384695001715, 0.00027058320034377496, 0.0002587148782742374, 0.00024809086993199717, 0.00023693855730337366, 0.0002282507754417191, 0.00022031878153725034, 0.00021815271654358024, 0.0002045452338814571, 0.00019706551778226103, 0.00019002843627651388, 0.0001844407169675619, 0.00017817121774629305, 0.00017225533241905985, 0.00016851070519324448, 0.00016327866144231795, 0.0001579179693448275, 0.00015318737795870917, 0.00014883789652444916, 0.00014545178911409013, 0.00014506131992675364, 0.00013773206635104383, 0.00013385336698026068, 0.00013058477484532085, 0.00012730956672492218, 0.00012449162686072455, 0.00012149684548871043, 0.0001185430414539844, 0.00011599305349484305]


epoch = list(range(100))


plt.figure(figsize=(12, 6))
plt.plot(epoch, accuracies, marker='o', linestyle='-', color='b', label='Accuracy')
plt.title('Accuracy vs Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()



plt.plot(range(len(loss)), loss, marker='o', linestyle='-', color='b', label='Loss')
plt.title('Loss vs Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.show()
48 changes: 48 additions & 0 deletions train.py
import torch
import torch.nn as nn
from dataset import returnDataLoaders
from modules import *
import torch.optim as optim

dataloader_train, dataloader_test, dataloader_val = returnDataLoaders()

losses = []
accuracies = []

def train(net, dataloader_train, dataloader_val, cross_entropy):
    optimizer = optim.Adam(net.parameters(), lr=2e-4)
    epochs = 100
    # training loop
    for epoch in range(epochs):
        epoch_loss = 0
        net.train()
        for (x_batch, y_batch) in dataloader_train:  # for each mini-batch
            optimizer.zero_grad()
            loss = cross_entropy(net(x_batch), y_batch)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach().item()
        epoch_loss = epoch_loss / len(dataloader_train)
        losses.append(epoch_loss)

        # validation accuracy after each epoch
        net.eval()
        acc = test(net, dataloader_val)
        print("epoch:", epoch, "accuracy:", acc, "loss:", epoch_loss, flush=True)
        accuracies.append(acc)

def test(net, dataloader_val):
    # mean per-batch accuracy over the validation set
    with torch.no_grad():
        acc = 0
        for (x_batch, y_batch) in dataloader_val:
            predictions = torch.max(net(x_batch).detach(), 1)[1]
            acc += torch.sum(y_batch == predictions, axis=0) / len(y_batch)
        acc = acc / len(dataloader_val)
    return acc

vit = VisionTransformer(input_dimen=128,
                        hiddenlayer_dimen=256,
                        number_heads=4,
                        transform_layers=4,
                        predict_num=2,
                        size_patch=(16,16))
cross_entropy = nn.CrossEntropyLoss()
train(vit, dataloader_train, dataloader_val, cross_entropy)
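
As written, training runs entirely on the CPU. A device-aware variant (a sketch, not part of the submitted code) would move the model and each mini-batch to the GPU when one is available:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit = vit.to(device)
# and inside the training/test loops:
#     x_batch, y_batch = x_batch.to(device), y_batch.to(device)
```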