-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathClassifierInferPytorch.py
94 lines (83 loc) · 2.73 KB
/
ClassifierInferPytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
import torch
import torch.nn as nn
import torchio as tio
import pandas as pd
from rsna_cropped import RSNACervicalSpineFracture
import plotly.graph_objects as go
from tqdm import tqdm
from sklearn.metrics import classification_report,roc_curve
# Read local settings. The keys used below are RSNA_2022_PATH, CACHE_DIR and
# classifier_weights.
with open('config.json', 'r', encoding='utf-8') as f:
    paths = json.load(f)
if torch.cuda.is_available():
    print("GPU enabled")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RSNA_2022_PATH = paths["RSNA_2022_PATH"]
cachedir = paths["CACHE_DIR"]
classWeights = paths["classifier_weights"]
# The checkpoint is a whole pickled nn.Module: it is called and .eval()'d directly.
# torch>=2.6 defaults to weights_only=True, which refuses to unpickle arbitrary
# classes, so full unpickling is requested explicitly. This runs pickle on the
# file: only point classifier_weights at checkpoints you trust.
classModel = torch.load(classWeights, map_location=device, weights_only=False)
classModel.eval()
# Label columns read from every sample, in a fixed order. The per-vertebra labels
# are presumably C1..C7, followed by a patient-level label. This order is assumed
# to match the order of the classifier's output vector (predictions and labels are
# paired by position below) — TODO confirm against training code.
pred_cols = ["C1", "C2", "C3", "C4", "C5", "C6", "C7", "patient_overall"]
root_dir = "./"
def column(matrix, i):
    """Return a list holding element ``i`` of each row of ``matrix``.

    ``matrix`` is any iterable of indexable rows (lists, tuples, 1-D arrays).
    The result has one entry per row, in row order.
    """
    return [entry[i] for entry in matrix]
# Pre-cropped dataset from the local rsna_cropped module. The stock
# torchio.datasets.RSNACervicalSpineFracture (same arguments) can be swapped in
# to evaluate on uncropped volumes.
trainSet = RSNACervicalSpineFracture(RSNA_2022_PATH, add_segmentations=False)

# Stateless module: build it once instead of once per sample.
sig = nn.Sigmoid()
with torch.no_grad():
    # NOTE: despite the name, this collects sigmoid *probabilities*, one array per
    # sample in pred_cols order (the name is kept because later code uses it).
    predicted_logits = []
    # Ground-truth values per sample, in the same pred_cols order.
    actual = []
    for classifier_input in tqdm(trainSet):
        # Add a leading batch axis, run on `device`, then take the single output
        # row back on the CPU. Assumes the model returns (batch, n_labels) — TODO confirm.
        logits = classModel(classifier_input.ct.data.unsqueeze(0).to(device)).cpu()[0]
        gt = [classifier_input[target_col] for target_col in pred_cols]
        preds = sig(logits)
        overall = preds.numpy().squeeze()
        predicted_logits.append(overall)
        actual.append(gt)
# Build one 3-D ROC trace per label:
#   x = false-positive rate, y = true-positive rate,
#   z = the score threshold that yields each (fpr, tpr) point.
# NOTE(review): scikit-learn's roc_curve prepends a sentinel threshold
# (inf in recent versions), so the first point of each trace may render
# off-scale or be dropped.
scatterPlots = []
for i, label_name in enumerate(pred_cols):
    fpr, tpr, thresholds = roc_curve(column(actual, i), column(predicted_logits, i))
    trace = go.Scatter3d(
        x=fpr,
        y=tpr,
        z=thresholds,
        name=label_name,
        showlegend=True,
        marker=dict(size=5),
        line=dict(width=2),
    )
    scatterPlots.append(trace)

scene_axes = dict(
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    zaxis_title='Threshold',
)
fig = go.Figure(data=scatterPlots)
fig.update_layout(
    scene=scene_axes,
    width=1920,
    height=1080,
    margin=dict(r=20, b=10, l=10, t=10),
)
# Save an interactive HTML copy, then open the figure for threshold picking.
fig.write_html("classifier_roc_plot.html")
fig.show()
# Ask the operator for one decision threshold per label (read off the ROC plot),
# binarize the probabilities with those thresholds, and save a per-label
# classification report to modelReport.csv.
print("choose thresholds for report")
thresholds = []
for label in pred_cols:
    print(label)
    # Re-prompt on unparsable input rather than crashing after the (slow)
    # inference pass has already finished.
    while True:
        try:
            ele = float(input())
        except ValueError:
            print(f"not a number, enter a threshold for {label}")
            continue
        break
    thresholds.append(ele)

# A label is predicted positive when its probability is strictly above its
# threshold. zip pairs each probability with the threshold in pred_cols order.
predicted = [
    [int(prob > threshold) for prob, threshold in zip(element, thresholds)]
    for element in predicted_logits
]
# classification_report's signature is (y_true, y_pred). Ground truth must come
# first; otherwise precision and recall are exchanged and "support" counts
# predicted positives instead of true ones.
report = classification_report(actual, predicted, output_dict=True,
                               target_names=pred_cols)
df = pd.DataFrame(report).transpose()
df.to_csv("modelReport.csv")