Skip to content

Commit

Permalink
working image cropping using unet segmentation (#4)
Browse files Browse the repository at this point in the history
* note for cropped training

* change sampling

* fix

* this doesnt work

* strip out bad subjects

* scc

* more bad scans found

* wrong augment order

* tolerance

* really high tolerance

* transform

* transform to CT

* reintroduce some subjects

* fix epoch counter

* disable patience

* oneHot encoding

* one hot is bad

* downsample

* smaller net

* invalid params

* AMP doesnt play nice with monai losses

* got amp working

* add early exit back

* add data augmentations

* missing dep

* typo

* lower learning rate, too erratic

* incrase val size

* no early stopping

* lets try VNet

* amp gives more vram, try larger unet

* sanity check

* bigger Unet

* sigmoid loss, add saving of intermediate inference visualizations

* add dice metric to tensorboard

* trying to get segmentation volumes

* print

* fix model loading

* moar fixing

* fix input to model

* fix model checkpointing

* use two channels for better segmentation

* exclude background

* back to one channel

* use oneHOT encoding for training

* disable AMP

* fix oneHot encoding

* transform val set too

* label mapping

* use softmax instead of sigmoid

* longer training, extra logging

* bugfix

* need more resolution, use in model downsampling

* stronger downsampling

* big model large strides

* input size adjustment

* fix metric

* more cores

* save every epoch

* Update segmenterTrainPytorch.py

* Update segmenterTrainPytorch.py

* use resizing for reversible masking

* metrics fix

* Update segmenterTrainPytorch.py

* dont forget to actually apply the resize

* cache the resize

* working image cropping with unet
  • Loading branch information
NevesLucas authored Dec 4, 2022
1 parent 742db52 commit 1b0cdee
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 34 deletions.
3 changes: 2 additions & 1 deletion ClassifierTrainPytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def competiton_loss_row_norm(y_hat, y):
train, val = dataset.loadDatasetAsSegmentor()


# TODO: use Segmentation ground truth data to crop train and val volumees into Regions of interest

# TODO: use Segmentation ground truth data to crop train and val volumes into Regions of interest

#train = CroppedROITrainSet
#val = CroppedROIValSet
Expand Down
66 changes: 53 additions & 13 deletions kaggleDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,53 @@ def __init__(self):
self.trainBbox = pd.read_csv(os.path.join(RSNA_2022_PATH, "train_bounding_boxes.csv"))
self.testDf = pd.read_csv(os.path.join(RSNA_2022_PATH, "test.csv"))
self.ss = pd.read_csv(os.path.join(RSNA_2022_PATH, "sample_submission.csv"))
# https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/344862
bad_scans = ['1.2.826.0.1.3680043.20574', '1.2.826.0.1.3680043.29952']
for uid in bad_scans:
self.trainDf.drop(self.trainDf[self.trainDf['StudyInstanceUID'] == uid].index, axis = 0, inplace = True)

#get the mappings for the data images and segmentations:
seg_paths = []
img_paths = []
UIDs = self.listTrainPatientID()
for uid in tqdm(UIDs):
seg_paths.append(os.path.join(self.segPath, str(uid)+".nii"))
img_paths.append(os.path.join(self.trainImagePath,str(uid)))

self.trainDf["seg_path"] = seg_paths
self.trainDf["img_paths"] = img_paths
self.trainDf.head()

def bboxFromIndex(self, id, sliceNum):
box = self.trainBbox.loc[(self.trainBbox.StudyInstanceUID == id) & (self.trainBbox.slice_number == sliceNum), :]
return list(box.values[0])

def fracturedBones(self, id):
fractured_bones = []
temp = self.trainDf.loc[self.trainDf.StudyInstanceUID == id, ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']]
temp = list(temp.values[0]) # there is one row per id
for i in range(len(temp)):
if temp[i] == 1:
fractured_bones.append('C' + str(i + 1))
return fractured_bones

def listTrainPatientID(self):
return list(self.trainDf["StudyInstanceUID"])

def listTestPatientID(self):
return list(self.testDf["StudyInstanceUID"])

def loadSliceImageFromId(self, patientID, sliceIndex):
imgPath = self.trainDf.loc[self.trainDf.StudyInstanceUID == patientID, "img_paths"]
imgPath = imgPath.iloc[0]
targetPath = os.path.join(imgPath, str(sliceIndex)+".dcm")
return loadDicom(targetPath)

def loadSegmentationsForPatient(self, patientID):
segmentations = nib.load(os.path.join(self.segPath, patientID+'.nii')).get_fdata()
segmentations = segmentations[:, ::-1, ::-1]
segmentations = segmentations.transpose(2, 1, 0)
return segmentations

## Dataset generator functions
def loadDatasetAsClassifier(self, trainPercentage=0.90, train_aug=None):
Expand All @@ -69,14 +116,10 @@ def loadDatasetAsClassifier(self, trainPercentage=0.90, train_aug=None):
rescale,
])
normalize_orientation = tio.ToCanonical()
downsample = tio.Resample(1)

cropOrPad = tio.CropOrPad((130,130,200))
preprocess_spatial = tio.Compose([
normalize_orientation,
downsample,
cropOrPad,
])
normalize_orientation])

preprocess = tio.Compose([
preprocess_spatial,
preprocess_intensity,
Expand Down Expand Up @@ -123,18 +166,14 @@ def loadDatasetAsSegmentor(self, trainPercentage=0.90, train_aug=None):
])
normalize_orientation = tio.ToCanonical()
transform = tio.Resample('ct')
downsample = tio.Resample(1)
cropOrPad = tio.CropOrPad((128,128,200))
preprocess_spatial = tio.Compose([
normalize_orientation,
downsample,
transform,
cropOrPad,
# downsample,
transform
])
sequential = tio.SequentialLabels()
remapping = {2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1}
remapping = {2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1}
remap_mask = tio.RemapLabels(remapping)

preprocess = tio.Compose([
sequential,
remap_mask,
Expand All @@ -151,6 +190,7 @@ def loadDatasetAsSegmentor(self, trainPercentage=0.90, train_aug=None):
num_val = num_subjects - num_train
train_set, val_set = torch.utils.data.random_split(trainSet,[num_train,num_val])
train_set.dataset.set_transform(preprocess)
val_set.dataset.set_transform(preprocess)
if train_aug is not None:
val_set = copy.deepcopy(val_set)
augment = tio.Compose([
Expand Down
97 changes: 97 additions & 0 deletions segmenterInferPytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import kaggleDataLoader
import json

from joblib import Memory
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from monai.data import decollate_batch, DataLoader,Dataset,ImageDataset
from monai.metrics import DiceMetric
from monai.losses.dice import DiceLoss
from monai.networks.nets import BasicUNet
from monai.visualize import plot_2d_or_3d_image
from monai.transforms import AsDiscrete

from torchvision.ops import masks_to_boxes
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
import torch.cuda.amp as amp
import torchio as tio

with open('config.json', 'r') as f:
paths = json.load(f)

def boundingVolume(pred,original_dims):
#acquires the 3d bounding rectangular prism of the segmentation mask
indices = torch.nonzero(pred)
min_indices, min_val = indices.min(dim=0)
max_indices, max_val = indices.max(dim=0)
print(min_indices)
print(max_indices)
return (min_indices[1].item(), original_dims[0]-max_indices[1].item(),
min_indices[2].item(), original_dims[1]-max_indices[2].item(),
min_indices[3].item(), original_dims[2]-max_indices[3].item())


cachedir = paths["CACHE_DIR"]
modelWeights = paths["seg_weights"]

root_dir="./"

if torch.cuda.is_available():
print("GPU enabled")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = kaggleDataLoader.KaggleDataLoader()
train, val = dataset.loadDatasetAsClassifier(trainPercentage = 1.0)

model = torch.load(modelWeights, map_location=device)
model.eval()

resize = tio.Resize((128, 128, 200)) #resize for segmentation

basic_sample = train[10]
# get original dims first
original_dims = basic_sample.spatial_shape
downsampled = resize(basic_sample)

reverseTransform = tio.Resize(original_dims)

prediction = model(downsampled.ct.data.unsqueeze(0) ) #get mask for current subject

binary_mask = torch.argmax(prediction, dim=1) # binarize

fig, ax = plt.subplots()
ims = []
for sagittal_slice_tensor in binary_mask[0]:
im = ax.imshow(sagittal_slice_tensor.detach().numpy(), cmap=plt.cm.bone, animated=True)
ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
repeat_delay=1000)
plt.show()

binary_mask = reverseTransform(binary_mask) # convert mask back to original image resolution
bounding_prism = boundingVolume(binary_mask,original_dims) # find the bounding area of the segmentation

crop = tio.Crop(bounding_prism)
cropped_original = crop(basic_sample) # crop the original data to fit the segmentation.


fig, ax = plt.subplots()
ims = []
for sagittal_slice_tensor in cropped_original.ct.data[0]:
im = ax.imshow(sagittal_slice_tensor.detach().numpy(), animated=True)
ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
repeat_delay=1000)
plt.show()
45 changes: 25 additions & 20 deletions segmenterTrainPytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from monai.data import decollate_batch, DataLoader,Dataset,ImageDataset
from monai.metrics import DiceMetric
from monai.losses.dice import DiceLoss
from monai.networks.nets import BasicUNet

from monai.networks.nets import UNet, BasicUNet
from monai.networks.layers import Norm
from monai.visualize import plot_2d_or_3d_image
from monai.transforms import AsDiscrete

import torch.cuda.amp as amp
import torchio as tio
Expand All @@ -23,17 +26,21 @@

cachedir = paths["CACHE_DIR"]
memory = Memory(cachedir, verbose=0, compress=True)

resize = tio.Resize((128, 128, 200))
def cacheFunc(data, indexes):
return data[indexes]
return resize(data[indexes])

cacheFunc = memory.cache(cacheFunc)


oneHot = tio.OneHot()
flip = tio.RandomFlip(axes=('LR'))
aniso = tio.RandomAnisotropy()
noise = tio.RandomNoise()

augmentations = tio.Compose([flip,aniso,noise])
augmentations = tio.Compose([flip,aniso,noise,oneHot])
toDiscrete = AsDiscrete(argmax=True, to_onehot=2)


class cachingDataset(Dataset):

Expand Down Expand Up @@ -61,27 +68,29 @@ def __getitem__(self, idx):
train, batch_size=1, shuffle=True, prefetch_factor=4, persistent_workers=True, drop_last=True, num_workers=16)

val_loader = DataLoader(
val, batch_size=1, num_workers=8)

N_EPOCHS = 500
val, batch_size=1, num_workers=16)

N_EPOCHS = 300
model = BasicUNet(spatial_dims=3,
in_channels=1,
features=(32, 64, 128, 256, 512, 32),
out_channels=1).to(device)
out_channels=2).to(device)

optimizer = torch.optim.Adam(model.parameters(), 1e-5)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)
scaler = amp.GradScaler()
loss = DiceLoss(sigmoid=True)
loss = DiceLoss(softmax=True)
val_interval = 1
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_metric = DiceMetric(include_background=False, reduction="mean")

PATIENCE = 10

loss_hist = []
val_loss_hist = []
patience_counter = 0
best_val_loss = np.inf
batchCount = 0
#https://www.kaggle.com/code/samuelcortinhas/rnsa-3d-model-train-pytorch
writer = SummaryWriter()
#Loop over epochs
Expand All @@ -99,6 +108,7 @@ def __getitem__(self, idx):
imgs = batch['ct']['data']

labels = batch['seg']['data']

imgs = imgs.to(device)
labels = labels.to(device)

Expand Down Expand Up @@ -135,6 +145,7 @@ def __getitem__(self, idx):

# Forward pass
val_preds = model(val_imgs)
val_preds = toDiscrete(val_preds)
dice_metric(y_pred=val_preds, y=val_labels)
# Track loss
valid_count += 1
Expand All @@ -151,26 +162,20 @@ def __getitem__(self, idx):

#tensorboard logging
plot_2d_or_3d_image(val_imgs,epoch+1,writer,index=0,tag='image')
plot_2d_or_3d_image(val_labels,epoch+1,writer,index=0,tag='GT')
plot_2d_or_3d_image(val_preds,epoch+1,writer,index=0,tag='output')

# Print loss
if (epoch + 1) % 1 == 0:
print(
f'Epoch {epoch + 1}/{N_EPOCHS}, loss {loss_acc / train_count:.5f}, val_loss {val_loss_acc / valid_count:.5f}')
f'Epoch {epoch + 1}/{N_EPOCHS}, loss {loss_acc / train_count:.5f}, val_loss {metric:.5f}')

# Save model (& early stopping)
if (val_loss_acc / valid_count) < best_val_loss:
best_val_loss = val_loss_acc / valid_count
if (metric) < best_val_loss:
best_val_loss = metric
patience_counter = 0
print('Valid loss improved --> saving model')
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimiser_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss_acc / train_count,
'val_loss': val_loss_acc / valid_count,
}, "Unet3D.pt")
torch.save(model, str("Unet3D_resized_128x128x200"+str(epoch)+".pt"))

writer.close()
print('')
Expand Down

0 comments on commit 1b0cdee

Please sign in to comment.