Skip to content

Commit 8d79869

Browse files
author
zhijian-yang
committed
Add SmileGAN Plugin
1 parent 3bdaf6b commit 8d79869

File tree

5 files changed

+463
-0
lines changed

5 files changed

+463
-0
lines changed

NiBAx/plugins/SmileGAN/SmileGAN.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
from PyQt5.QtGui import *
2+
from matplotlib.backends.backend_qt5 import FigureCanvasQT
3+
from PyQt5 import QtGui, QtCore, QtWidgets, uic
4+
import joblib
5+
import sys, os, time
6+
import seaborn as sns
7+
import numpy as np
8+
import pandas as pd
9+
from NiBAx.core.plotcanvas import PlotCanvas
10+
from NiBAx.core.baseplugin import BasePlugin
11+
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox
12+
from SmileGAN.Smile_GAN_clustering import clustering_result
13+
14+
class computeSmileGANs(QtWidgets.QWidget,BasePlugin):
15+
16+
#constructor
17+
def __init__(self):
18+
super(computeSmileGANs,self).__init__()
19+
self.model = []
20+
root = os.path.dirname(__file__)
21+
self.readAdditionalInformation(root)
22+
self.ui = uic.loadUi(os.path.join(root, 'SmileGAN.ui'),self)
23+
self.ui.comboBoxHue = SearchableQComboBox(self.ui)
24+
self.ui.horizontalLayout_3.addWidget(self.comboBoxHue)
25+
self.plotCanvas = PlotCanvas(self.ui.page_2)
26+
self.ui.verticalLayout.addWidget(self.plotCanvas)
27+
self.plotCanvas.axes = self.plotCanvas.fig.add_subplot(111)
28+
self.SPAREs = None
29+
self.ui.stackedWidget.setCurrentIndex(0)
30+
self.ui.factorial_progressBar.setValue(0)
31+
32+
# Initialize thread
33+
self.thread = QtCore.QThread()
34+
35+
36+
def getUI(self):
37+
return self.ui
38+
39+
40+
def SetupConnections(self):
41+
#pass
42+
self.ui.load_SmileGAN_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
43+
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
44+
self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame())
45+
self.ui.compute_SmileGAN_Btn.clicked.connect(lambda check: self.OnComputeSPAREs(check))
46+
self.ui.show_SmileGAN_prob_from_data_Btn.clicked.connect(lambda: self.OnShowSPAREs())
47+
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
48+
self.ui.comboBoxHue.currentIndexChanged.connect(self.plotSPAREs)
49+
50+
self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: green; color: white")
51+
# Set `Show SPARE-* from data` button to visible when SPARE-* columns
52+
# are present in data frame
53+
if ('SPARE_BA' in self.datamodel.GetColumnHeaderNames() and
54+
'SPARE_AD' in self.datamodel.GetColumnHeaderNames()):
55+
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
56+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
57+
self.ui.show_SmileGAN_prob_from_data_Btn.setToolTip('The data frame has variables `SPARE_AD` and `SPARE_BA` so these can be plotted.')
58+
else:
59+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
60+
61+
# Allow loading of SPARE-* model always, even when residuals are not
62+
# calculated yet
63+
self.ui.load_SmileGAN_model_Btn.setEnabled(True)
64+
65+
66+
def updateProgress(self, txt, vl):
67+
self.ui.SPARE_computation_info.setText(txt)
68+
self.ui.factorial_progressBar.setValue(vl)
69+
70+
71+
def OnLoadSmileGANModel(self):
72+
fileNames, _ = QtWidgets.QFileDialog.getOpenFileNames(None,
73+
'Open SPARE-* model file',
74+
QtCore.QDir().homePath(),
75+
"")
76+
if len(fileNames) > 0:
77+
self.model = fileNames
78+
self.ui.compute_SmileGAN_Btn.setEnabled(True)
79+
model_info = 'File:'
80+
for file in fileNames:
81+
model_info += file + '\n'
82+
self.ui.SPARE_model_info.setText(model_info)
83+
else:
84+
return
85+
86+
self.ui.stackedWidget.setCurrentIndex(0)
87+
'''
88+
89+
if 'RES_ICV_Sex_MUSE_Volume_47' in self.datamodel.GetColumnHeaderNames():
90+
self.ui.compute_SPARE_scores_Btn.setStyleSheet("QPushButton"
91+
"{"
92+
"background-color : rgb(230,255,230);"
93+
"}"
94+
"QPushButton::checked"
95+
"{"
96+
"background-color : rgb(255,230,230);"
97+
"border: none;"
98+
"}"
99+
)
100+
self.ui.compute_SPARE_scores_Btn.setEnabled(True)
101+
self.ui.compute_SPARE_scores_Btn.setChecked(False)
102+
self.ui.compute_SPARE_scores_Btn.setToolTip('Model loaded and `RES_ICV_Sex_MUSE_Volmue_*` available so the MUSE volumes can be harmonized.')
103+
else:
104+
self.ui.compute_SPARE_scores_Btn.setStyleSheet("background-color: rgb(255,230,230)")
105+
self.ui.compute_SPARE_scores_Btn.setEnabled(False)
106+
self.ui.compute_SPARE_scores_Btn.setToolTip('Model loaded but `RES_ICV_Sex_MUSE_Volmue_*` not available so the MUSE volumes can not be harmonized.')
107+
108+
109+
print('No field `RES_ICV_Sex_MUSE_Volume_47` found. ' +
110+
'Make sure to compute and add harmonized residuals first.')
111+
'''
112+
113+
114+
def OnComputationDone(self, y_hat):
115+
self.SPAREs = y_hat
116+
self.ui.compute_SmileGAN_Btn.setText('Compute SPARE-*')
117+
if self.SPAREs.empty:
118+
return
119+
self.ui.compute_SmileGAN_Btn.setChecked(False)
120+
self.ui.stackedWidget.setCurrentIndex(1)
121+
self.ui.comboBoxHue.setVisible(False)
122+
self.plotSPAREs(False)
123+
124+
# Activate buttons
125+
self.ui.compute_SmileGAN_Btn.setEnabled(False)
126+
if ('SPARE_BA' in self.datamodel.GetColumnHeaderNames() and
127+
'SPARE_AD' in self.datamodel.GetColumnHeaderNames()):
128+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
129+
#self.ui.show_SPARE_scores_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
130+
else:
131+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
132+
self.ui.load_SmileGAN_model_Btn.setEnabled(True)
133+
134+
135+
136+
def OnComputeSPAREs(self, checked):
137+
# Setup tasks for long running jobs
138+
# Using this example: https://realpython.com/python-pyqt-qthread/
139+
# Disable buttons
140+
if checked is not True:
141+
self.thread.requestInterruption()
142+
else:
143+
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
144+
"{"
145+
"background-color : rgb(230,255,230);"
146+
"}"
147+
"QPushButton::checked"
148+
"{"
149+
"background-color : rgb(255,230,230);"
150+
"}"
151+
)
152+
self.ui.compute_SmileGAN_Btn.setText('Cancel computation')
153+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
154+
self.ui.load_SmileGAN_model_Btn.setEnabled(False)
155+
self.thread = QtCore.QThread()
156+
self.worker = BrainAgeWorker(self.datamodel.data, self.model)
157+
self.worker.moveToThread(self.thread)
158+
self.thread.started.connect(self.worker.run)
159+
self.worker.done.connect(self.thread.quit)
160+
self.worker.done.connect(self.worker.deleteLater)
161+
self.thread.finished.connect(self.thread.deleteLater)
162+
self.worker.progress.connect(self.updateProgress)
163+
self.worker.done.connect(lambda y_hat: self.OnComputationDone(y_hat))
164+
self.ui.factorial_progressBar.setRange(0, len(self.model['BrainAge']['scaler'])-1)
165+
self.thread.start()
166+
167+
168+
def PopulateHue(self):
169+
#add the list items to comboBoxHue
170+
datakeys = self.datamodel.GetColumnHeaderNames()
171+
datatypes = self.datamodel.GetColumnDataTypes()
172+
categoryList = ['Sex','Study','A','T','N','PIB_Status'] + [k for k,d in zip(datakeys, datatypes) if d.name=='category']
173+
categoryList = list(set(categoryList).intersection(set(datakeys)))
174+
self.ui.comboBoxHue.clear()
175+
self.ui.comboBoxHue.addItems(categoryList)
176+
177+
178+
def plotSPAREs(self, useExistingSPAREs=True):
179+
# Plot data
180+
if self.ui.stackedWidget.currentIndex() == 0:
181+
return
182+
self.plotCanvas.axes.clear()
183+
plotOptions = {'HUE': self.ui.comboBoxHue.currentText()}
184+
185+
if useExistingSPAREs:
186+
print(self.SPAREs[plotOptions['HUE']].value_counts())
187+
kws = {"s": 20}
188+
sns.scatterplot(x='SPARE_AD', y='SPARE_BA', data=self.SPAREs,
189+
ax=self.plotCanvas.axes, linewidth=0, hue=plotOptions['HUE'],
190+
facecolor=(0.5, 0.5, 0.5, 0.5), **kws)
191+
else:
192+
kws = {"s": 20}
193+
sns.scatterplot(x='SPARE_AD', y='SPARE_BA', data=self.SPAREs,
194+
ax=self.plotCanvas.axes, linewidth=0,
195+
facecolor=(0.5, 0.5, 0.5, 0.5), legend=None)
196+
197+
198+
199+
sns.despine(ax=self.plotCanvas.axes, trim=True)
200+
self.plotCanvas.axes.set(ylabel='SPARE-BA', xlabel='SPARE-AD')
201+
self.plotCanvas.axes.get_figure().set_tight_layout(True)
202+
self.plotCanvas.draw()
203+
204+
205+
def OnAddToDataFrame(self):
206+
print('Adding SPARE-* scores to data frame...')
207+
self.datamodel.data.loc[:,'SPARE_AD'] = self.SPAREs['SPARE_AD']
208+
self.datamodel.data.loc[:,'SPARE_BA'] = self.SPAREs['SPARE_BA']
209+
self.datamodel.data_changed.emit()
210+
self.OnShowSPAREs()
211+
212+
213+
def OnShowSPAREs(self):
214+
allHues = [self.ui.comboBoxHue.itemText(i) for i in range(self.ui.comboBoxHue.count())]
215+
216+
self.SPAREs = self.datamodel.data[['SPARE_BA', 'SPARE_AD'] + allHues]
217+
self.ui.stackedWidget.setCurrentIndex(1)
218+
self.ui.comboBoxHue.setVisible(True)
219+
self.plotSPAREs()
220+
221+
222+
def OnDataChanged(self):
223+
# Set `Show SPARE-* from data` button to visible when SPARE-* columns
224+
# are present in data frame
225+
self.ui.stackedWidget.setCurrentIndex(0)
226+
if ('SPARE_BA' in self.datamodel.GetColumnHeaderNames() and
227+
'SPARE_AD' in self.datamodel.GetColumnHeaderNames()):
228+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
229+
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
230+
else:
231+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
232+
233+
self.PopulateHue()
234+
235+
236+
class BrainAgeWorker(QtCore.QObject):
237+
238+
done = QtCore.pyqtSignal(pd.DataFrame)
239+
progress = QtCore.pyqtSignal(str, int)
240+
241+
#constructor
242+
def __init__(self, data, model):
243+
super(BrainAgeWorker, self).__init__()
244+
self.data = data
245+
self.model = model
246+
247+
def run(self):
248+
y_hat = pd.DataFrame.from_dict({'SPARE_BA': np.full((self.data.shape[0],),np.nan),
249+
'SPARE_AD': np.full((self.data.shape[0],),np.nan)})
250+
251+
# SPARE-BA
252+
idx = ~self.data[self.model['BrainAge']['predictors'][0]].isnull()
253+
254+
y_hat_test = np.zeros((np.sum(idx),))
255+
n_ensembles = np.zeros((np.sum(idx),))
256+
257+
for i,_ in enumerate(self.model['BrainAge']['scaler']):
258+
# Predict validation (fold) and test
259+
self.progress.emit('Computing SPARE-BA | Task 1 of 2', i)
260+
if QtCore.QThread.currentThread().isInterruptionRequested():
261+
self.progress.emit('Cancelled.', 0)
262+
self.done.emit(pd.DataFrame())
263+
return
264+
test = np.logical_not(self.data[idx]['participant_id'].isin(np.concatenate(self.model['BrainAge']['train']))) | self.data[idx]['participant_id'].isin(self.model['BrainAge']['validation'][i])
265+
X = self.data[idx].loc[test, self.model['BrainAge']['predictors']].values
266+
X = self.model['BrainAge']['scaler'][i].transform(X)
267+
y_hat_test[test] += (self.model['BrainAge']['svm'][i].predict(X) - self.model['BrainAge']['bias_ints'][i]) / self.model['BrainAge']['bias_slopes'][i]
268+
n_ensembles[test] += 1.
269+
270+
y_hat_test /= n_ensembles
271+
y_hat.loc[idx, 'SPARE_BA'] = y_hat_test
272+
273+
idx = ~self.data[self.model['AD']['predictors'][0]].isnull()
274+
275+
y_hat_test = np.zeros((np.sum(idx),))
276+
n_ensembles = np.zeros((np.sum(idx),))
277+
278+
for i,_ in enumerate(self.model['AD']['scaler']):
279+
# Predict validation (fold) and test
280+
self.progress.emit('Computing SPARE-AD | Task 2 of 2', i)
281+
if QtCore.QThread.currentThread().isInterruptionRequested():
282+
self.progress.emit('Cancelled.', 0)
283+
# Emit the result
284+
self.done.emit(pd.DataFrame())
285+
return
286+
test = np.logical_not(self.data[idx]['participant_id'].isin(np.concatenate(self.model['AD']['train']))) | self.data[idx]['participant_id'].isin(self.model['AD']['validation'][i])
287+
X = self.data[idx].loc[test, self.model['AD']['predictors']].values
288+
X = self.model['AD']['scaler'][i].transform(X)
289+
y_hat_test[test] += self.model['AD']['svm'][i].decision_function(X)
290+
n_ensembles[test] += 1.
291+
292+
y_hat_test /= n_ensembles
293+
y_hat.loc[idx, 'SPARE_AD'] = y_hat_test
294+
295+
self.progress.emit('All done.', i)
296+
297+
# Emit the result
298+
self.done.emit(y_hat)

0 commit comments

Comments
 (0)