Skip to content

Commit

Permalink
updates from SCC workflow (#5)
Browse files Browse the repository at this point in the history
* note for cropped training

* change sampling

* fix

* this doesn't work

* strip out bad subjects

* scc

* more bad scans found

* wrong augment order

* tolerance

* really high tolerance

* transform

* transform to CT

* reintroduce some subjects

* fix epoch counter

* disable patience

* oneHot encoding

* one hot is bad

* downsample

* smaller net

* invalid params

* AMP doesn't play nice with monai losses

* got amp working

* add early exit back

* add data augmentations

* missing dep

* typo

* lower learning rate, too erratic

* increase val size

* no early stopping

* lets try VNet

* amp gives more vram, try larger unet

* sanity check

* bigger Unet

* sigmoid loss, add saving of intermediate inference visualizations

* add dice metric to tensorboard

* trying to get segmentation volumes

* print

* fix model loading

* moar fixing

* fix input to model

* fix model checkpointing

* use two channels for better segmentation

* exclude background

* back to one channel

* use oneHOT encoding for training

* disable AMP

* fix oneHot encoding

* transform val set too

* label mapping

* use softmax instead of sigmoid

* longer training, extra logging

* bugfix

* need more resolution, use in model downsampling

* stronger downsampling

* big model large strides

* input size adjustment

* fix metric

* more cores

* save every epoch

* Update segmenterTrainPytorch.py

* Update segmenterTrainPytorch.py

* use resizing for reversible masking

* metrics fix

* Update segmenterTrainPytorch.py

* dont forget to actually apply the resize

* cache the resize

* working image cropping with unet

* segmenter data cropping training pipeline

* Update ClassifierTrainPytorch.py

* Update .gitignore

* use 2 gpus

* put data on the right device

* two gpus too difficult, just use cpu processing

* Update ClassifierTrainPytorch.py

* switch to better tboard reporting tool

* Update ClassifierTrainPytorch.py

* Update ClassifierTrainPytorch.py

* Update ClassifierTrainPytorch.py

* metrics+denseNet201

* Update ClassifierTrainPytorch.py

* loading dataset wrong

* Update ClassifierTrainPytorch.py

* Update ClassifierTrainPytorch.py

* fp16 validation too

* Update ClassifierTrainPytorch.py

* Create ClassifierTrainPytorch2.py

* add crop conversion

* Update DatasetLoaderConvertData.py

* swap to using pre-cropped data

* fix imports

* Update ClassifierTrainPytorch2.py

* lower batchsize

* Update ClassifierTrainPytorch2.py

* multi-gpu_test

* Update ClassifierTrainPytorch.py

* Update ClassifierTrainPytorch.py

* Update ClassifierTrainPytorch.py

* perf stats evaluation

* Update ClassifierTrainPytorch.py

* Update ClassifierInferPytorch.py

* output model stats to file

* better stats reporting

* Update ClassifierInferPytorch.py

* Update ClassifierInferPytorch.py

* Update ClassifierInferPytorch.py

* Update ClassifierInferPytorch.py

* Update ClassifierInferPytorch.py

* Update ClassifierTrainPytorch.py
  • Loading branch information
NevesLucas authored Dec 14, 2022
1 parent 1b0cdee commit dfa46c1
Show file tree
Hide file tree
Showing 11 changed files with 667 additions and 153 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,9 @@ dmypy.json
# pycharm
.idea
*.json


#tensorboard
runs/
#artifacts
*.pt
94 changes: 94 additions & 0 deletions ClassifierInferPytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import json
import torch
import torch.nn as nn
import torchio as tio
import pandas as pd
from rsna_cropped import RSNACervicalSpineFracture
import plotly.graph_objects as go
from tqdm import tqdm
from sklearn.metrics import classification_report,roc_curve
# --- Runtime setup: configuration, compute device, trained classifier ---

# All external paths (dataset root, cache dir, model weights) live in config.json.
with open('config.json', 'r') as f:
    paths = json.load(f)

cuda_ok = torch.cuda.is_available()
if cuda_ok:
    print("GPU enabled")

device = torch.device('cuda' if cuda_ok else 'cpu')

RSNA_2022_PATH = paths["RSNA_2022_PATH"]
cachedir = paths["CACHE_DIR"]
classWeights = paths["classifier_weights"]

# The checkpoint is a fully pickled model object; map it onto whatever
# device is available and switch to inference mode.
classModel = torch.load(classWeights, map_location=device)
classModel.eval()


# Output columns: one fracture probability per cervical vertebra (C1-C7)
# plus the patient-level "any fracture" flag.
pred_cols = [f"C{i}" for i in range(1, 8)] + ["patient_overall"]

root_dir = "./"
def column(matrix, i):
    """Return the i-th element of every row in *matrix* as a list."""
    result = []
    for row in matrix:
        result.append(row[i])
    return result

#trainSet = tio.datasets.RSNACervicalSpineFracture(RSNA_2022_PATH, add_segmentations=False)
trainSet = RSNACervicalSpineFracture(RSNA_2022_PATH, add_segmentations=False)  # pre-cropped data

# Run the classifier over every subject, collecting the per-label sigmoid
# probabilities together with the ground-truth labels for the ROC analysis
# and thresholded report below.
with torch.no_grad():
    predicted_logits = []  # NOTE: holds post-sigmoid probabilities, not raw logits
    actual = []

    for classifier_input in tqdm(trainSet):
        # The model expects a leading batch dimension; unsqueeze adds it.
        # assumes classifier_input.ct.data is (C, D, H, W) — TODO confirm
        logits = classModel(classifier_input.ct.data.unsqueeze(0).to(device)).cpu()[0]
        gt = [classifier_input[target_col] for target_col in pred_cols]
        # torch.sigmoid is the functional form of nn.Sigmoid; using it avoids
        # constructing a fresh nn.Sigmoid module on every loop iteration.
        preds = torch.sigmoid(logits)
        overall = preds.numpy().squeeze()
        predicted_logits.append(overall)
        actual.append(gt)

# --- Interactive 3-D ROC curves (FPR x TPR x threshold), one trace per label ---
scatterPlots = []
for i, label in enumerate(pred_cols):
    fpr, tpr, thresholds = roc_curve(column(actual, i), column(predicted_logits, i))
    scatterPlots.append(go.Scatter3d(
        x=fpr,
        y=tpr,
        z=thresholds,
        name=label,
        showlegend=True,
        marker=dict(
            size=5
        ),
        line=dict(
            width=2)
    ))
fig = go.Figure(data=scatterPlots)
fig.update_layout(scene=dict(
                      xaxis_title='False Positive Rate',
                      yaxis_title='True Positive Rate',
                      zaxis_title='Threshold'),
                  width=1920,
                  height=1080,
                  margin=dict(r=20, b=10, l=10, t=10))

fig.write_html("classifier_roc_plot.html")
fig.show()

# The operator inspects the ROC plot and types one decision threshold per label.
print("choose thresholds for report")
thresholds = []
for label in pred_cols:
    print(label)
    thresholds.append(float(input()))

# Binarize the probabilities with the chosen per-label thresholds.
predicted = [[(ele > threshold) * 1 for ele, threshold in zip(element, thresholds)]
             for element in predicted_logits]

# BUGFIX: sklearn's signature is classification_report(y_true, y_pred) —
# the original call passed (predicted, actual), which swaps precision and
# recall in the report. Ground truth must come first.
report = classification_report(actual, predicted, output_dict=True,
                               target_names=pred_cols)

df = pd.DataFrame(report).transpose()
df.to_csv("modelReport.csv")
120 changes: 63 additions & 57 deletions ClassifierTrainPytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,34 @@
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from monai.data import decollate_batch, DataLoader,Dataset,ImageDataset
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from sklearn.metrics import classification_report
import torch.cuda.amp as amp
import torchio as tio

with open('config.json', 'r') as f:
paths = json.load(f)

segWeights = paths["seg_weights"]
cachedir = paths["CACHE_DIR"]
memory = Memory(cachedir, verbose=0, compress=True)

def cacheFunc(data, indexes):

return data[indexes]

cacheFunc = memory.cache(cacheFunc)

flip = tio.RandomFlip()
affine = tio.RandomAffine()
gamma = tio.RandomGamma(0.5)
aniso = tio.RandomAnisotropy(p=0.25)
noise = tio.RandomNoise(p=0.25)
augmentations = tio.Compose([flip, affine, aniso, noise, gamma])

class cachingDataset(Dataset):

def __init__(self, data):
Expand All @@ -32,7 +44,8 @@ def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
return cacheFunc(self.dataset,idx)
batch = cacheFunc(self.dataset, idx)
return augmentations(batch)


# Replicate competition metric (https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/341854)
Expand All @@ -47,10 +60,7 @@ def __getitem__(self, idx):
'C4', 'C5', 'C6', 'C7',
'patient_overall']


# Replicate competition metric (https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/341854)
loss_fn = nn.BCEWithLogitsLoss(reduction='none')

competition_weights = {
'-' : torch.tensor([1, 1, 1, 1, 1, 1, 1, 7], dtype=torch.float, device=device),
'+' : torch.tensor([2, 2, 2, 2, 2, 2, 2, 14], dtype=torch.float, device=device),
Expand All @@ -61,45 +71,45 @@ def __getitem__(self, idx):

# with row-wise weights normalization (https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/discussion/344565)
def competiton_loss_row_norm(y_hat, y):
loss = loss_fn(y_hat, y.to(y_hat.dtype))
loss = loss_fn(y_hat, y)
weights = y * competition_weights['+'] + (1 - y) * competition_weights['-']
loss = (loss * weights).sum(axis=1)
w_sum = weights.sum(axis=1)
loss = torch.div(loss, w_sum)
return loss.mean()

dataset = kaggleDataLoader.KaggleDataLoader()
train, val = dataset.loadDatasetAsSegmentor()



# TODO: use Segmentation ground truth data to crop train and val volumes into Regions of interest

#train = CroppedROITrainSet
#val = CroppedROIValSet
train, val = dataset.loadDatasetAsClassifier()

train = cachingDataset(train)
val = cachingDataset(val)

train_loader = DataLoader(
train, batch_size=4, shuffle=True, prefetch_factor=4, persistent_workers=True, drop_last=True, num_workers=16)
train, batch_size=16, shuffle=True, prefetch_factor=16, persistent_workers=True, drop_last=True, num_workers=32)

val_loader = DataLoader(
val, batch_size=1, num_workers=8)
val, batch_size=8, num_workers=32)

# train_loader = DataLoader(
# train, batch_size=1, shuffle=True, num_workers=0)
# val_loader = DataLoader(
# val, batch_size=1, num_workers=0)

n_epochs = 10
N_EPOCHS = 500
model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
val_interval = 1
auc_metric = ROCAUCMetric()
model = nn.DataParallel(model)
model.to(device)

N_EPOCHS = 20
PATIENCE = 3
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)
scaler = amp.GradScaler()

val_interval = 1
loss_hist = []
val_loss_hist = []
patience_counter = 0
best_val_loss = np.inf
writer = SummaryWriter()
#https://www.kaggle.com/code/samuelcortinhas/rnsa-3d-model-train-pytorch
#Loop over epochs
for epoch in tqdm(range(N_EPOCHS)):
Expand All @@ -118,21 +128,27 @@ def competiton_loss_row_norm(y_hat, y):
labels = labels.to(device)

# Forward pass
preds = model(imgs)
L = competiton_loss_row_norm(preds, labels)
with amp.autocast(dtype=torch.float16):
preds = model(imgs)
L = competiton_loss_row_norm(preds, labels)

# Backprop
L.backward()
# Update parameters
optimizer.step()

# Zero gradients
scaler.scale(L).backward()
scaler.step(optimizer)
scaler.update()

# # Backprop
# L.backward()
# # Update parameters
# optimizer.step()
# #
# # Zero gradients
optimizer.zero_grad()

# Track loss
#Track loss
loss_acc += L.detach().item()
train_count += 1
print("finished batch")
print("finished batch " + str(train_count))
# Update learning rate
scheduler.step()

Expand All @@ -148,55 +164,45 @@ def competiton_loss_row_norm(y_hat, y):
val_labels = val_labels.to(device)

# Forward pass
val_preds = model(val_imgs)
val_L = competiton_loss_row_norm(val_preds, val_labels)
with amp.autocast(dtype=torch.float16):
val_preds = model(val_imgs)
val_L = competiton_loss_row_norm(val_preds, val_labels)

# Track loss
val_loss_acc += val_L.item()
valid_count += 1
print("finished validation batch")

# Save loss history
loss_hist.append(loss_acc / train_count)
val_loss_hist.append(val_loss_acc / valid_count)
# Save loss history
loss_hist.append(loss_acc / train_count)
val_loss_hist.append(val_loss_acc / valid_count)

writer.add_scalar("train_loss", loss_acc / train_count,epoch + 1)
writer.add_scalar("val_loss", val_loss_acc / valid_count, epoch + 1)

# Print loss
if (epoch + 1) % 1 == 0:
print(
f'Epoch {epoch + 1}/{N_EPOCHS}, loss {loss_acc / train_count:.5f}, val_loss {val_loss_acc / valid_count:.5f}')

# Save model (& early stopping)
if (val_loss_acc / valid_count) < best_val_loss:
best_val_loss = val_loss_acc / valid_count
patience_counter = 0
print('Valid loss improved --> saving model')
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimiser_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss_acc / train_count,
'val_loss': val_loss_acc / valid_count,
}, "Conv3DNet.pt")
else:
patience_counter += 1

if patience_counter == PATIENCE:
break
torch.save(model, str("classifier_dist_DenseNet121_" + str(epoch)+".pt"))

writer.close()
print('')
print('Training complete!')
# log loss
data = {'val_loss':val_loss_hist,'loss':loss_hist}
df = pd.DataFrame(data=data)
df.to_csv("results.csv", sep='\t')
df.to_csv("train_log_densenet121.csv", sep='\t')

# Plot loss
plt.figure(figsize=(10,5))
plt.figure(figsize=(10, 5))
plt.plot(loss_hist, c='C0', label='loss')
plt.plot(val_loss_hist, c='C1', label='val_loss')
plt.title('Competition metric')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig("train_result.png")
plt.savefig("train_result_densenet121.png")
plt.show()
Loading

0 comments on commit dfa46c1

Please sign in to comment.