Skip to content

Commit 9930ee1

Browse files
author
zhijian-yang
committed
Add SmileGAN Plugin
1 parent 3bdaf6b commit 9930ee1

File tree

4 files changed

+456
-0
lines changed

4 files changed

+456
-0
lines changed

NiBAx/plugins/SmileGAN/SmileGAN.py

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
from PyQt5.QtGui import *
2+
from matplotlib.backends.backend_qt5 import FigureCanvasQT
3+
from PyQt5 import QtGui, QtCore, QtWidgets, uic
4+
import joblib
5+
import sys, os, time
6+
import seaborn as sns
7+
import numpy as np
8+
import pandas as pd
9+
from NiBAx.core.plotcanvas import PlotCanvas
10+
from NiBAx.core.baseplugin import BasePlugin
11+
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox
12+
13+
class computeSPAREs(QtWidgets.QWidget,BasePlugin):
14+
15+
#constructor
16+
def __init__(self):
17+
super(computeSPAREs,self).__init__()
18+
self.model = {'BrainAge': None, 'AD': None}
19+
root = os.path.dirname(__file__)
20+
self.readAdditionalInformation(root)
21+
self.ui = uic.loadUi(os.path.join(root, 'SmileGAN.ui'),self)
22+
self.ui.comboBoxHue = SearchableQComboBox(self.ui)
23+
self.ui.horizontalLayout_3.addWidget(self.comboBoxHue)
24+
self.plotCanvas = PlotCanvas(self.ui.page_2)
25+
self.ui.verticalLayout.addWidget(self.plotCanvas)
26+
self.plotCanvas.axes = self.plotCanvas.fig.add_subplot(111)
27+
self.SPAREs = None
28+
self.ui.stackedWidget.setCurrentIndex(0)
29+
self.ui.factorial_progressBar.setValue(0)
30+
31+
# Initialize thread
32+
self.thread = QtCore.QThread()
33+
34+
35+
def getUI(self):
36+
return self.ui
37+
38+
39+
def SetupConnections(self):
40+
#pass
41+
self.ui.load_SPARE_model_Btn.clicked.connect(lambda: self.OnLoadSPAREModel())
42+
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadSPAREModel())
43+
self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame())
44+
self.ui.compute_SPARE_scores_Btn.clicked.connect(lambda check: self.OnComputeSPAREs(check))
45+
self.ui.show_SPARE_scores_from_data_Btn.clicked.connect(lambda: self.OnShowSPAREs())
46+
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
47+
self.ui.comboBoxHue.currentIndexChanged.connect(self.plotSPAREs)
48+
49+
self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: green; color: white")
50+
# Set `Show SPARE-* from data` button to visible when SPARE-* columns
51+
# are present in data frame
52+
if ('SPARE_BA' in self.datamodel.GetColumnHeaderNames() and
53+
'SPARE_AD' in self.datamodel.GetColumnHeaderNames()):
54+
self.ui.show_SPARE_scores_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
55+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(True)
56+
self.ui.show_SPARE_scores_from_data_Btn.setToolTip('The data frame has variables `SPARE_AD` and `SPARE_BA` so these can be plotted.')
57+
else:
58+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(False)
59+
60+
# Allow loading of SPARE-* model always, even when residuals are not
61+
# calculated yet
62+
self.ui.load_SPARE_model_Btn.setEnabled(True)
63+
64+
65+
def updateProgress(self, txt, vl):
66+
self.ui.SPARE_computation_info.setText(txt)
67+
self.ui.factorial_progressBar.setValue(vl)
68+
69+
70+
def OnLoadSPAREModel(self):
71+
fileName, _ = QtWidgets.QFileDialog.getOpenFileName(None,
72+
'Open SPARE-* model file',
73+
QtCore.QDir().homePath(),
74+
"Pickle files (*.pkl.gz *.pkl)")
75+
if fileName != "":
76+
self.model['BrainAge'], self.model['AD'] = joblib.load(fileName)
77+
self.ui.compute_SPARE_scores_Btn.setEnabled(True)
78+
self.ui.SPARE_model_info.setText('File: %s' % (fileName))
79+
else:
80+
return
81+
82+
self.ui.stackedWidget.setCurrentIndex(0)
83+
84+
if 'RES_ICV_Sex_MUSE_Volume_47' in self.datamodel.GetColumnHeaderNames():
85+
self.ui.compute_SPARE_scores_Btn.setStyleSheet("QPushButton"
86+
"{"
87+
"background-color : rgb(230,255,230);"
88+
"}"
89+
"QPushButton::checked"
90+
"{"
91+
"background-color : rgb(255,230,230);"
92+
"border: none;"
93+
"}"
94+
)
95+
self.ui.compute_SPARE_scores_Btn.setEnabled(True)
96+
self.ui.compute_SPARE_scores_Btn.setChecked(False)
97+
self.ui.compute_SPARE_scores_Btn.setToolTip('Model loaded and `RES_ICV_Sex_MUSE_Volmue_*` available so the MUSE volumes can be harmonized.')
98+
else:
99+
self.ui.compute_SPARE_scores_Btn.setStyleSheet("background-color: rgb(255,230,230)")
100+
self.ui.compute_SPARE_scores_Btn.setEnabled(False)
101+
self.ui.compute_SPARE_scores_Btn.setToolTip('Model loaded but `RES_ICV_Sex_MUSE_Volmue_*` not available so the MUSE volumes can not be harmonized.')
102+
103+
104+
print('No field `RES_ICV_Sex_MUSE_Volume_47` found. ' +
105+
'Make sure to compute and add harmonized residuals first.')
106+
107+
108+
def OnComputationDone(self, y_hat):
109+
self.SPAREs = y_hat
110+
self.ui.compute_SPARE_scores_Btn.setText('Compute SPARE-*')
111+
if self.SPAREs.empty:
112+
return
113+
self.ui.compute_SPARE_scores_Btn.setChecked(False)
114+
self.ui.stackedWidget.setCurrentIndex(1)
115+
self.ui.comboBoxHue.setVisible(False)
116+
self.plotSPAREs(False)
117+
118+
# Activate buttons
119+
self.ui.compute_SPARE_scores_Btn.setEnabled(False)
120+
if ('SPARE_BA' in self.datamodel.GetColumnHeaderNames() and
121+
'SPARE_AD' in self.datamodel.GetColumnHeaderNames()):
122+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(True)
123+
#self.ui.show_SPARE_scores_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
124+
else:
125+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(False)
126+
self.ui.load_SPARE_model_Btn.setEnabled(True)
127+
128+
129+
130+
def OnComputeSPAREs(self, checked):
131+
# Setup tasks for long running jobs
132+
# Using this example: https://realpython.com/python-pyqt-qthread/
133+
# Disable buttons
134+
if checked is not True:
135+
self.thread.requestInterruption()
136+
else:
137+
self.ui.compute_SPARE_scores_Btn.setStyleSheet("QPushButton"
138+
"{"
139+
"background-color : rgb(230,255,230);"
140+
"}"
141+
"QPushButton::checked"
142+
"{"
143+
"background-color : rgb(255,230,230);"
144+
"}"
145+
)
146+
self.ui.compute_SPARE_scores_Btn.setText('Cancel computation')
147+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(False)
148+
self.ui.load_SPARE_model_Btn.setEnabled(False)
149+
self.thread = QtCore.QThread()
150+
self.worker = BrainAgeWorker(self.datamodel.data, self.model)
151+
self.worker.moveToThread(self.thread)
152+
self.thread.started.connect(self.worker.run)
153+
self.worker.done.connect(self.thread.quit)
154+
self.worker.done.connect(self.worker.deleteLater)
155+
self.thread.finished.connect(self.thread.deleteLater)
156+
self.worker.progress.connect(self.updateProgress)
157+
self.worker.done.connect(lambda y_hat: self.OnComputationDone(y_hat))
158+
self.ui.factorial_progressBar.setRange(0, len(self.model['BrainAge']['scaler'])-1)
159+
self.thread.start()
160+
161+
162+
def PopulateHue(self):
163+
#add the list items to comboBoxHue
164+
datakeys = self.datamodel.GetColumnHeaderNames()
165+
datatypes = self.datamodel.GetColumnDataTypes()
166+
categoryList = ['Sex','Study','A','T','N','PIB_Status'] + [k for k,d in zip(datakeys, datatypes) if d.name=='category']
167+
categoryList = list(set(categoryList).intersection(set(datakeys)))
168+
self.ui.comboBoxHue.clear()
169+
self.ui.comboBoxHue.addItems(categoryList)
170+
171+
172+
def plotSPAREs(self, useExistingSPAREs=True):
173+
# Plot data
174+
if self.ui.stackedWidget.currentIndex() == 0:
175+
return
176+
self.plotCanvas.axes.clear()
177+
plotOptions = {'HUE': self.ui.comboBoxHue.currentText()}
178+
179+
if useExistingSPAREs:
180+
print(self.SPAREs[plotOptions['HUE']].value_counts())
181+
kws = {"s": 20}
182+
sns.scatterplot(x='SPARE_AD', y='SPARE_BA', data=self.SPAREs,
183+
ax=self.plotCanvas.axes, linewidth=0, hue=plotOptions['HUE'],
184+
facecolor=(0.5, 0.5, 0.5, 0.5), **kws)
185+
else:
186+
kws = {"s": 20}
187+
sns.scatterplot(x='SPARE_AD', y='SPARE_BA', data=self.SPAREs,
188+
ax=self.plotCanvas.axes, linewidth=0,
189+
facecolor=(0.5, 0.5, 0.5, 0.5), legend=None)
190+
191+
192+
193+
sns.despine(ax=self.plotCanvas.axes, trim=True)
194+
self.plotCanvas.axes.set(ylabel='SPARE-BA', xlabel='SPARE-AD')
195+
self.plotCanvas.axes.get_figure().set_tight_layout(True)
196+
self.plotCanvas.draw()
197+
198+
199+
def OnAddToDataFrame(self):
200+
print('Adding SPARE-* scores to data frame...')
201+
self.datamodel.data.loc[:,'SPARE_AD'] = self.SPAREs['SPARE_AD']
202+
self.datamodel.data.loc[:,'SPARE_BA'] = self.SPAREs['SPARE_BA']
203+
self.datamodel.data_changed.emit()
204+
self.OnShowSPAREs()
205+
206+
207+
def OnShowSPAREs(self):
208+
allHues = [self.ui.comboBoxHue.itemText(i) for i in range(self.ui.comboBoxHue.count())]
209+
210+
self.SPAREs = self.datamodel.data[['SPARE_BA', 'SPARE_AD'] + allHues]
211+
self.ui.stackedWidget.setCurrentIndex(1)
212+
self.ui.comboBoxHue.setVisible(True)
213+
self.plotSPAREs()
214+
215+
216+
def OnDataChanged(self):
217+
# Set `Show SPARE-* from data` button to visible when SPARE-* columns
218+
# are present in data frame
219+
self.ui.stackedWidget.setCurrentIndex(0)
220+
if ('SPARE_BA' in self.datamodel.GetColumnHeaderNames() and
221+
'SPARE_AD' in self.datamodel.GetColumnHeaderNames()):
222+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(True)
223+
self.ui.show_SPARE_scores_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
224+
else:
225+
self.ui.show_SPARE_scores_from_data_Btn.setEnabled(False)
226+
227+
self.PopulateHue()
228+
229+
230+
class BrainAgeWorker(QtCore.QObject):
231+
232+
done = QtCore.pyqtSignal(pd.DataFrame)
233+
progress = QtCore.pyqtSignal(str, int)
234+
235+
#constructor
236+
def __init__(self, data, model):
237+
super(BrainAgeWorker, self).__init__()
238+
self.data = data
239+
self.model = model
240+
241+
def run(self):
242+
y_hat = pd.DataFrame.from_dict({'SPARE_BA': np.full((self.data.shape[0],),np.nan),
243+
'SPARE_AD': np.full((self.data.shape[0],),np.nan)})
244+
245+
# SPARE-BA
246+
idx = ~self.data[self.model['BrainAge']['predictors'][0]].isnull()
247+
248+
y_hat_test = np.zeros((np.sum(idx),))
249+
n_ensembles = np.zeros((np.sum(idx),))
250+
251+
for i,_ in enumerate(self.model['BrainAge']['scaler']):
252+
# Predict validation (fold) and test
253+
self.progress.emit('Computing SPARE-BA | Task 1 of 2', i)
254+
if QtCore.QThread.currentThread().isInterruptionRequested():
255+
self.progress.emit('Cancelled.', 0)
256+
self.done.emit(pd.DataFrame())
257+
return
258+
test = np.logical_not(self.data[idx]['participant_id'].isin(np.concatenate(self.model['BrainAge']['train']))) | self.data[idx]['participant_id'].isin(self.model['BrainAge']['validation'][i])
259+
X = self.data[idx].loc[test, self.model['BrainAge']['predictors']].values
260+
X = self.model['BrainAge']['scaler'][i].transform(X)
261+
y_hat_test[test] += (self.model['BrainAge']['svm'][i].predict(X) - self.model['BrainAge']['bias_ints'][i]) / self.model['BrainAge']['bias_slopes'][i]
262+
n_ensembles[test] += 1.
263+
264+
y_hat_test /= n_ensembles
265+
y_hat.loc[idx, 'SPARE_BA'] = y_hat_test
266+
267+
idx = ~self.data[self.model['AD']['predictors'][0]].isnull()
268+
269+
y_hat_test = np.zeros((np.sum(idx),))
270+
n_ensembles = np.zeros((np.sum(idx),))
271+
272+
for i,_ in enumerate(self.model['AD']['scaler']):
273+
# Predict validation (fold) and test
274+
self.progress.emit('Computing SPARE-AD | Task 2 of 2', i)
275+
if QtCore.QThread.currentThread().isInterruptionRequested():
276+
self.progress.emit('Cancelled.', 0)
277+
# Emit the result
278+
self.done.emit(pd.DataFrame())
279+
return
280+
test = np.logical_not(self.data[idx]['participant_id'].isin(np.concatenate(self.model['AD']['train']))) | self.data[idx]['participant_id'].isin(self.model['AD']['validation'][i])
281+
X = self.data[idx].loc[test, self.model['AD']['predictors']].values
282+
X = self.model['AD']['scaler'][i].transform(X)
283+
y_hat_test[test] += self.model['AD']['svm'][i].decision_function(X)
284+
n_ensembles[test] += 1.
285+
286+
y_hat_test /= n_ensembles
287+
y_hat.loc[idx, 'SPARE_AD'] = y_hat_test
288+
289+
self.progress.emit('All done.', i)
290+
291+
# Emit the result
292+
self.done.emit(y_hat)

0 commit comments

Comments
 (0)