Skip to content

Commit b67dcc0

Browse files
author
zhijian-yang
committed
Add SmileGAN Plugin
1 parent 9dac338 commit b67dcc0

File tree

5 files changed

+366
-0
lines changed

5 files changed

+366
-0
lines changed

NiBAx/plugins/SmileGAN/SmileGAN.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from PyQt5.QtGui import *
2+
from matplotlib.backends.backend_qt5 import FigureCanvasQT
3+
from PyQt5 import QtGui, QtCore, QtWidgets, uic
4+
import joblib
5+
import sys, os, time
6+
import seaborn as sns
7+
import numpy as np
8+
import pandas as pd
9+
from NiBAx.core.plotcanvas import PlotCanvas
10+
from NiBAx.core.baseplugin import BasePlugin
11+
from NiBAx.core.gui.SearchableQComboBox import SearchableQComboBox
12+
from SmileGAN2.Smile_GAN_clustering import clustering_result
13+
14+
class computeSmileGANs(QtWidgets.QWidget,BasePlugin):
15+
16+
#constructor
17+
def __init__(self):
18+
super(computeSmileGANs,self).__init__()
19+
self.model = []
20+
root = os.path.dirname(__file__)
21+
self.readAdditionalInformation(root)
22+
self.ui = uic.loadUi(os.path.join(root, 'SmileGAN.ui'),self)
23+
self.ui.comboBoxHue = SearchableQComboBox(self.ui)
24+
self.ui.horizontalLayout_3.addWidget(self.comboBoxHue)
25+
self.plotCanvas = PlotCanvas(self.ui.page_2)
26+
self.ui.verticalLayout.addWidget(self.plotCanvas)
27+
self.plotCanvas.axes = self.plotCanvas.fig.add_subplot(111)
28+
self.SmileGANpatterns = None
29+
self.ui.stackedWidget.setCurrentIndex(0)
30+
31+
# Initialize thread
32+
self.thread = QtCore.QThread()
33+
34+
35+
def getUI(self):
36+
return self.ui
37+
38+
39+
def SetupConnections(self):
40+
#pass
41+
self.ui.load_SmileGAN_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
42+
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadSmileGANModel())
43+
self.ui.add_to_dataframe_Btn.clicked.connect(lambda: self.OnAddToDataFrame())
44+
self.ui.compute_SmileGAN_Btn.clicked.connect(lambda check: self.OnComputePatterns(check))
45+
self.ui.show_SmileGAN_prob_from_data_Btn.clicked.connect(lambda: self.OnShowPatterns())
46+
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
47+
self.ui.comboBoxHue.currentIndexChanged.connect(self.plotPattern)
48+
49+
self.ui.add_to_dataframe_Btn.setStyleSheet("background-color: green; color: white")
50+
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column
51+
# are present in data frame
52+
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
53+
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
54+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
55+
self.ui.show_SmileGAN_prob_from_data_Btn.setToolTip('The data frame has variables `SmileGAN patterns` so these can be plotted.')
56+
else:
57+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
58+
59+
# Allow loading of SmileGAN-* model always, even when residuals are not
60+
# calculated yet
61+
self.ui.load_SmileGAN_model_Btn.setEnabled(True)
62+
63+
64+
def OnLoadSmileGANModel(self):
65+
fileNames, _ = QtWidgets.QFileDialog.getOpenFileNames(None,
66+
'Open SmileGAN model file',
67+
QtCore.QDir().homePath(),
68+
"")
69+
if len(fileNames) > 0:
70+
self.model = fileNames
71+
self.ui.compute_SmileGAN_Btn.setEnabled(True)
72+
model_info = 'File:'
73+
for file in fileNames:
74+
model_info += file + '\n'
75+
self.ui.SmileGAN_model_info.setText(model_info)
76+
else:
77+
return
78+
79+
self.ui.stackedWidget.setCurrentIndex(0)
80+
81+
if 'RES_ICV_Sex_MUSE_Volume_47' in self.datamodel.GetColumnHeaderNames():
82+
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
83+
"{"
84+
"background-color : rgb(230,255,230);"
85+
"}"
86+
"QPushButton::checked"
87+
"{"
88+
"background-color : rgb(255,230,230);"
89+
"border: none;"
90+
"}"
91+
)
92+
self.ui.compute_SmileGAN_Btn.setEnabled(True)
93+
self.ui.compute_SmileGAN_Btn.setChecked(False)
94+
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded and `RES_ICV_Sex_MUSE_Volmue_*` available so the MUSE volumes can be harmonized.')
95+
else:
96+
self.ui.compute_SmileGAN_Btn.setStyleSheet("background-color: rgb(255,230,230)")
97+
self.ui.compute_SmileGAN_Btn.setEnabled(False)
98+
self.ui.compute_SmileGAN_Btn.setToolTip('Model loaded but `RES_ICV_Sex_MUSE_Volmue_*` not available so the MUSE volumes can not be harmonized.')
99+
100+
101+
print('No field `RES_ICV_Sex_MUSE_Volume_47` found. ' +
102+
'Make sure to compute and add harmonized residuals first.')
103+
104+
def OnComputationDone(self, p):
105+
self.SmileGANpatterns = p
106+
self.ui.compute_SmileGAN_Btn.setText('Compute SmileGAN Patterns-*')
107+
if self.SmileGANpatterns.empty:
108+
return
109+
self.ui.compute_SmileGAN_Btn.setChecked(False)
110+
self.ui.stackedWidget.setCurrentIndex(1)
111+
self.ui.comboBoxHue.setVisible(False)
112+
self.plotPattern()
113+
114+
# Activate buttons
115+
self.ui.compute_SmileGAN_Btn.setEnabled(False)
116+
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
117+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
118+
else:
119+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
120+
self.ui.load_SmileGAN_model_Btn.setEnabled(True)
121+
122+
123+
124+
def OnComputePatterns(self, checked):
125+
# Setup tasks for long running jobs
126+
# Using this example: https://realpython.com/python-pyqt-qthread/
127+
# Disable buttons
128+
if checked is not True:
129+
self.thread.requestInterruption()
130+
else:
131+
self.ui.compute_SmileGAN_Btn.setStyleSheet("QPushButton"
132+
"{"
133+
"background-color : rgb(230,255,230);"
134+
"}"
135+
"QPushButton::checked"
136+
"{"
137+
"background-color : rgb(255,230,230);"
138+
"}"
139+
)
140+
self.ui.compute_SmileGAN_Btn.setText('Cancel computation')
141+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
142+
self.ui.load_SmileGAN_model_Btn.setEnabled(False)
143+
self.thread = QtCore.QThread()
144+
self.worker = PatternWorker(self.datamodel.data, self.model)
145+
self.worker.moveToThread(self.thread)
146+
self.thread.started.connect(self.worker.run)
147+
self.worker.done.connect(self.thread.quit)
148+
self.worker.done.connect(self.worker.deleteLater)
149+
self.thread.finished.connect(self.thread.deleteLater)
150+
self.worker.done.connect(lambda p: self.OnComputationDone(p))
151+
self.thread.start()
152+
153+
154+
def plotPattern(self):
155+
# Plot data
156+
if self.ui.stackedWidget.currentIndex() == 0:
157+
return
158+
self.plotCanvas.axes.clear()
159+
160+
sns.countplot(x='Pattern', data=self.SmileGANpatterns,
161+
ax=self.plotCanvas.axes)
162+
163+
sns.despine(ax=self.plotCanvas.axes, trim=True)
164+
self.plotCanvas.axes.set(ylabel='Count', xlabel='Patterns')
165+
self.plotCanvas.axes.get_figure().set_tight_layout(True)
166+
self.plotCanvas.draw()
167+
168+
169+
def OnAddToDataFrame(self):
170+
print('Adding SmileGAN patterns to data frame...')
171+
for col in self.SmileGANpatterns.columns:
172+
self.datamodel.data.loc[:,'SmileGAN_'+col] = self.SmileGANpatterns[col]
173+
self.datamodel.data_changed.emit()
174+
self.OnShowPatterns()
175+
176+
177+
def OnShowPatterns(self):
178+
self.ui.stackedWidget.setCurrentIndex(1)
179+
self.plotPattern()
180+
181+
182+
def OnDataChanged(self):
183+
# Set `Show SmileGAN patterns from data` button to visible when SmileGAN_Pattern column
184+
# are present in data frame
185+
self.ui.stackedWidget.setCurrentIndex(0)
186+
if ('SmileGAN_Pattern' in self.datamodel.GetColumnHeaderNames()):
187+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(True)
188+
self.ui.show_SmileGAN_prob_from_data_Btn.setStyleSheet("background-color: rgb(230,230,255)")
189+
else:
190+
self.ui.show_SmileGAN_prob_from_data_Btn.setEnabled(False)
191+
192+
193+
class PatternWorker(QtCore.QObject):
194+
195+
done = QtCore.pyqtSignal(pd.DataFrame)
196+
progress = QtCore.pyqtSignal(str, int)
197+
198+
#constructor
199+
def __init__(self, data, model_list):
200+
super(PatternWorker, self).__init__()
201+
self.data = data
202+
self.model = model_list
203+
204+
def run(self):
205+
train_data = self.data[['participant_id']+[ name for name in self.data.columns if ('H_MUSE_Volume' in name and int(name[14:])<300)] ]
206+
covariate = self.data[['participant_id','Age','Sex']]
207+
covariate['Sex'] = covariate['Sex'].map({'M':1,'F':0})
208+
train_data['diagnosis'] = 1
209+
covariate['diagnosis'] = 1
210+
cluster_label, cluster_prob, _, _ = clustering_result(self.model, 'highest_matching_clustering', train_data, covariate)
211+
p = pd.DataFrame(data = cluster_prob, columns = ['P'+str(_) for _ in range(1,cluster_prob.shape[1]+1)])
212+
p['Pattern'] = cluster_label
213+
214+
# Emit the result
215+
self.done.emit(p)

NiBAx/plugins/SmileGAN/SmileGAN.ui

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<ui version="4.0">
3+
<class>Form</class>
4+
<widget class="QWidget" name="Form">
5+
<property name="geometry">
6+
<rect>
7+
<x>0</x>
8+
<y>0</y>
9+
<width>891</width>
10+
<height>695</height>
11+
</rect>
12+
</property>
13+
<property name="windowTitle">
14+
<string>Compute SPAREs</string>
15+
</property>
16+
<layout class="QGridLayout" name="gridLayout">
17+
<item row="0" column="0">
18+
<widget class="QStackedWidget" name="stackedWidget">
19+
<property name="styleSheet">
20+
<string notr="true"/>
21+
</property>
22+
<property name="currentIndex">
23+
<number>0</number>
24+
</property>
25+
<widget class="QWidget" name="page">
26+
<layout class="QVBoxLayout" name="verticalLayout_2">
27+
<item>
28+
<spacer name="verticalSpacer_2">
29+
<property name="orientation">
30+
<enum>Qt::Vertical</enum>
31+
</property>
32+
<property name="sizeHint" stdset="0">
33+
<size>
34+
<width>338</width>
35+
<height>241</height>
36+
</size>
37+
</property>
38+
</spacer>
39+
</item>
40+
<item>
41+
<widget class="QPushButton" name="show_SmileGAN_prob_from_data_Btn">
42+
<property name="enabled">
43+
<bool>false</bool>
44+
</property>
45+
<property name="text">
46+
<string>Show SmileGAN patterns from data</string>
47+
</property>
48+
</widget>
49+
</item>
50+
<item>
51+
<widget class="QPushButton" name="load_SmileGAN_model_Btn">
52+
<property name="text">
53+
<string>Load SmileGAN Model</string>
54+
</property>
55+
</widget>
56+
</item>
57+
<item>
58+
<widget class="QLabel" name="SmileGAN_model_info">
59+
<property name="text">
60+
<string>No SmileGAN model loaded</string>
61+
</property>
62+
</widget>
63+
</item>
64+
<item>
65+
<widget class="QPushButton" name="compute_SmileGAN_Btn">
66+
<property name="enabled">
67+
<bool>false</bool>
68+
</property>
69+
<property name="text">
70+
<string>Compute SmileGAN patterns</string>
71+
</property>
72+
<property name="checkable">
73+
<bool>true</bool>
74+
</property>
75+
<property name="checked">
76+
<bool>false</bool>
77+
</property>
78+
</widget>
79+
</item>
80+
<item>
81+
<spacer name="verticalSpacer">
82+
<property name="orientation">
83+
<enum>Qt::Vertical</enum>
84+
</property>
85+
<property name="sizeHint" stdset="0">
86+
<size>
87+
<width>338</width>
88+
<height>241</height>
89+
</size>
90+
</property>
91+
</spacer>
92+
</item>
93+
<item>
94+
<widget class="Line" name="line">
95+
<property name="orientation">
96+
<enum>Qt::Horizontal</enum>
97+
</property>
98+
</widget>
99+
</item>
100+
</layout>
101+
</widget>
102+
<widget class="QWidget" name="page_2">
103+
<layout class="QVBoxLayout" name="verticalLayout">
104+
<item>
105+
<layout class="QHBoxLayout" name="horizontalLayout_3">
106+
<item>
107+
<widget class="QLabel" name="label">
108+
<property name="text">
109+
<string>SmileGAN</string>
110+
</property>
111+
</widget>
112+
</item>
113+
<item>
114+
<widget class="QPushButton" name="load_other_model_Btn">
115+
<property name="text">
116+
<string>Load other model</string>
117+
</property>
118+
</widget>
119+
</item>
120+
<item>
121+
<widget class="QPushButton" name="add_to_dataframe_Btn">
122+
<property name="text">
123+
<string>Add to DataFrame</string>
124+
</property>
125+
</widget>
126+
</item>
127+
</layout>
128+
</item>
129+
</layout>
130+
</widget>
131+
</widget>
132+
</item>
133+
</layout>
134+
</widget>
135+
<resources/>
136+
<connections/>
137+
</ui>
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[Core]
2+
Name = SmileGAN
3+
Module = SmileGAN
4+
5+
[Documentation]
6+
Author = Zhijian Yang
7+
Version = 0.1
8+
Website =
9+
Description = Compute SmileGAN patterns from existing model.
10+
11+
[Tab]
12+
#tab position starts from 0
13+
Position = 5

NiBAx/plugins/SmileGAN/__init__.py

Whitespace-only changes.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ six==1.16.0
2323
statsmodels==0.13.0
2424
wheel>=0.37.1
2525
Yapsy==1.12.2
26+
SmileGAN==0.1.1

0 commit comments

Comments
 (0)