-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* note for cropped training * change sampling * fix * this doesnt work * strip out bad subjects * scc * more bad scans found * wrong augment order * tolerance * really high tolerance * transform * transform to CT * reintroduce some subjects * fix epoch counter * disable patience * oneHot encoding * one hot is bad * downsample * smaller net * invalid params * AMP doesnt play nice with monai losses * got amp working * add early exit back * add data augmentations * missing dep * typo * lower learning rate, too erratic * incrase val size * no early stopping * lets try VNet * amp gives more vram, try larger unet * sanity check * bigger Unet * sigmoid loss, add saving of intermediate inference visualizations * add dice metric to tensorboard * trying to get segmentation volumes * print * fix model loading * moar fixing * fix input to model * fix model checkpointing * use two channels for better segmentation * exclude background * back to one channel * use oneHOT encoding for training * disable AMP * fix oneHot encoding * transform val set too * label mapping * use softmax instead of sigmoid * longer training, extra logging * bugfix * need more resolution, use in model downsampling * stronger downsampling * big model large strides * input size adjustment * fix metric * more cores * save every epoch * Update segmenterTrainPytorch.py * Update segmenterTrainPytorch.py * use resizing for reversible masking * metrics fix * Update segmenterTrainPytorch.py * dont forget to actually apply the resize * cache the resize * working image cropping with unet * segmenter data cropping training pipeline * Update ClassifierTrainPytorch.py * Update .gitignore * use 2 gpus * put data on the right device * two gpus too difficult, just use cpu processing * Update ClassifierTrainPytorch.py * switch to better tboard reporting tool * Update ClassifierTrainPytorch.py * Update ClassifierTrainPytorch.py * Update ClassifierTrainPytorch.py * metrics+denseNet201 * Update ClassifierTrainPytorch.py * loading dataset wrong * Update ClassifierTrainPytorch.py * Update ClassifierTrainPytorch.py * fp16 validation too * Update ClassifierTrainPytorch.py * Create ClassifierTrainPytorch2.py * add crop conversion * Update DatasetLoaderConvertData.py * swap to using pre-cropped data * fix imports * Update ClassifierTrainPytorch2.py * lower batchsize * Update ClassifierTrainPytorch2.py * multi-gpu_test * Update ClassifierTrainPytorch.py * Update ClassifierTrainPytorch.py * Update ClassifierTrainPytorch.py * perf stats evaluation * Update ClassifierTrainPytorch.py * Update ClassifierInferPytorch.py * output model stats to file * better stats reporting * Update ClassifierInferPytorch.py * Update ClassifierInferPytorch.py * Update ClassifierInferPytorch.py * Update ClassifierInferPytorch.py * Update ClassifierInferPytorch.py * Update ClassifierTrainPytorch.py
- Loading branch information
1 parent
1b0cdee
commit dfa46c1
Showing
11 changed files
with
667 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,3 +132,9 @@ dmypy.json | |
# pycharm | ||
.idea | ||
*.json | ||
|
||
|
||
#tensorboard | ||
runs/ | ||
#artifacts | ||
*.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import json | ||
import torch | ||
import torch.nn as nn | ||
import torchio as tio | ||
import pandas as pd | ||
from rsna_cropped import RSNACervicalSpineFracture | ||
import plotly.graph_objects as go | ||
from tqdm import tqdm | ||
from sklearn.metrics import classification_report,roc_curve | ||
with open('config.json', 'r') as f: | ||
paths = json.load(f) | ||
|
||
if torch.cuda.is_available(): | ||
print("GPU enabled") | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
RSNA_2022_PATH = paths["RSNA_2022_PATH"] | ||
cachedir = paths["CACHE_DIR"] | ||
classWeights = paths["classifier_weights"] | ||
classModel = torch.load(classWeights, map_location=device) | ||
classModel.eval() | ||
|
||
|
||
pred_cols = [ | ||
"C1", | ||
"C2", | ||
"C3", | ||
"C4", | ||
"C5", | ||
"C6", | ||
"C7", | ||
"patient_overall" | ||
] | ||
|
||
root_dir="./" | ||
def column(matrix, i): | ||
return [row[i] for row in matrix] | ||
|
||
#trainSet = tio.datasets.RSNACervicalSpineFracture(RSNA_2022_PATH, add_segmentations=False) | ||
trainSet = RSNACervicalSpineFracture(RSNA_2022_PATH, add_segmentations=False) # pre-cropped data | ||
with torch.no_grad(): | ||
predicted_logits = [] | ||
actual = [] | ||
|
||
for classifier_input in tqdm(trainSet): | ||
# get original dims first | ||
#classifier_input = preprocess(samples) | ||
logits = classModel(classifier_input.ct.data.unsqueeze(0).to(device)).cpu()[0] | ||
gt = [classifier_input[target_col] for target_col in pred_cols] | ||
sig = nn.Sigmoid() | ||
preds = sig(logits) | ||
overall = preds.numpy().squeeze() | ||
predicted_logits.append(overall) | ||
actual.append(gt) | ||
|
||
scatterPlots = [] | ||
for i in range(0,len(pred_cols)): | ||
fpr, tpr, thresholds = roc_curve(column(actual, i), column(predicted_logits, i)) | ||
scatterPlots.append(go.Scatter3d( | ||
x=fpr, | ||
y=tpr, | ||
z=thresholds, | ||
name=pred_cols[i], | ||
showlegend=True, | ||
marker=dict( | ||
size=5 | ||
), | ||
line=dict( | ||
width=2) | ||
)) | ||
fig = go.Figure(data=scatterPlots) | ||
fig.update_layout(scene=dict( | ||
xaxis_title='False Positive Rate', | ||
yaxis_title='True Positive Rate', | ||
zaxis_title='Threshold'), | ||
width=1920, | ||
height=1080, | ||
margin=dict(r=20, b=10, l=10, t=10)) | ||
|
||
fig.write_html("classifier_roc_plot.html") | ||
fig.show() | ||
print("choose thresholds for report") | ||
thresholds = [] | ||
for label in pred_cols: | ||
print(label) | ||
ele = float(input()) | ||
thresholds.append(ele) # adding the element | ||
|
||
predicted = [[(ele > threshold)*1 for ele,threshold in zip(element, thresholds)] for element in predicted_logits] | ||
report = classification_report(predicted, actual, output_dict=True, | ||
target_names=pred_cols) | ||
|
||
df = pd.DataFrame(report).transpose() | ||
df.to_csv("modelReport.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.