Skip to content

Commit

Permalink
remove bad config, use torchIO to produce train/test splits. visualiz…
Browse files Browse the repository at this point in the history
…e CT volumes in tests
  • Loading branch information
NevesLucas committed Oct 23, 2022
1 parent 8a89e9c commit 6a3eb19
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 43 deletions.
35 changes: 35 additions & 0 deletions TestDatasetLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation

dataLoader = kaggleDataLoader.KaggleDataLoader()

Expand Down Expand Up @@ -40,4 +41,38 @@
fig, ax = plt.subplots(1,1, figsize=(10, 10))
ax.imshow(imSlicePxArr)
ax.add_patch(rect)
plt.show()

train, val = dataLoader.loadDatasetAsClassifier()

subject1 = train[0]

fig, ax = plt.subplots()
ims = []
for sagittal_slice_tensor in subject1.ct.data[0]:
im = ax.imshow(sagittal_slice_tensor.numpy(), cmap=plt.cm.bone, animated=True)
ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
repeat_delay=1000)
plt.show()

fig, ax = plt.subplots()
ims = []
for coronal_slice_tensor in subject1.ct.data[0].permute(1,2,0):
im = ax.imshow(coronal_slice_tensor.numpy(), cmap=plt.cm.bone, animated=True)
ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
repeat_delay=1000)
plt.show()

fig, ax = plt.subplots()
ims = []
for axial_slice_tensor in subject1.ct.data[0].permute(2,1,0):
im = ax.imshow(axial_slice_tensor.numpy(), cmap=plt.cm.bone, animated=True)
ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
repeat_delay=1000)
plt.show()
13 changes: 3 additions & 10 deletions environment-CPU.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
# To run: conda env create -f environment-lnx-cuda.yml
# To run: conda env create -f environment-CPU.yml
name: csna
channels:
- conda-forge
dependencies:
- python=3.6
- python=3.9
- pip
- Pillow
- matplotlib
- numpy
- pandas
- pip:
- torch
- torchvision
- kaggle
- pydicom
- opencv-python
- nibabel
- -r requirements.txt
18 changes: 0 additions & 18 deletions environment-WIN.yml

This file was deleted.

39 changes: 34 additions & 5 deletions kaggleDataLoader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@

from PIL import Image

import copy
import pandas as pd
import pydicom
import nibabel as nib
import numpy as np
import json
import cv2
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from tqdm import tqdm
import random

import torchio as tio
import torch
with open('config.json', 'r') as f:
paths = json.load(f)

Expand Down Expand Up @@ -95,10 +96,38 @@ def loadSegmentationsForPatient(self, patientID):
return segmentations

## Dataset generator functions
def loadDatasetAsClassifier(self):
def loadDatasetAsClassifier(self, trainPercentage=0.90,train_aug=None,val_aug=None):
"""
prepare full dataset for training
"""

HOUNSFIELD_AIR, HOUNSFIELD_BONE = -1000, 1900
clamp = tio.Clamp(out_min=HOUNSFIELD_AIR, out_max=HOUNSFIELD_BONE)
rescale = tio.RescaleIntensity(percentiles=(0.5, 99.5))
preprocess_intensity = tio.Compose([
clamp,
rescale,
])
normalize_orientation = tio.ToCanonical()
downsample = tio.Resample(1)
preprocess_spatial = tio.Compose([
normalize_orientation,
downsample,
])
preprocess = tio.Compose([
preprocess_intensity,
preprocess_spatial,
])

trainSet = tio.datasets.RSNACervicalSpineFracture(RSNA_2022_PATH)
num_subjects = len(trainSet)
num_train = int(trainPercentage*num_subjects)
num_val = num_subjects - num_train
train_set, val_set = torch.utils.data.random_split(trainSet,[num_train,num_val])
train_set.dataset.set_transform(preprocess)
return train_set, val_set


def loadDatasetAsDetector(self):
"""
prepare full dataset for training
Expand Down
22 changes: 12 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
scipy==1.2.1
numpy==1.17.0
matplotlib==3.1.2
opencv_python==4.1.2.30
torch==1.2.0
torchvision==0.4.0
tqdm==4.60.0
Pillow==8.2.0
h5py==2.10.0
scipy
numpy
matplotlib
opencv_python
torch
torchvision
tqdm
Pillow
h5py
pylibjpeg
pylibjpeg-libjpeg
pylibjpeg-libjpeg
torchio
pydicom

0 comments on commit 6a3eb19

Please sign in to comment.