Skip to content

Commit e0c3a63

Browse files
committed
Add capability to plot model trends without data
1 parent 13f30cc commit e0c3a63

File tree

4 files changed

+41
-23
lines changed

4 files changed

+41
-23
lines changed

niCHART/core/model/datamodel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class DataModel(QObject):
2121
"""This class holds the data model."""
2222

2323
data_changed = QtCore.pyqtSignal()
24+
model_changed = QtCore.pyqtSignal()
2425

2526
def __init__(self):
2627
QObject.__init__(self)
@@ -90,7 +91,8 @@ def SetData(self,d):
9091
def SetHarmonizationModel(self,m):
9192
"""Setter for neuroHarmonize model"""
9293
self.harmonization_model = m
93-
logger.info('neuroHarmonize model set.')
94+
logger.info('neuroHarmonize model set. Signal emmitted')
95+
self.model_changed.emit()
9496

9597

9698
def SetSPAREModel(self,BrainAgeModel, ADModel):
@@ -190,6 +192,8 @@ def GetColumnHeaderNames(self):
190192
"""Returns all header names for all columns in the dataset."""
191193
if self.data is not None:
192194
k = self.data.keys()
195+
elif self.harmonization_model is not None:
196+
k = self.harmonization_model['ROIs']
193197
else:
194198
k = []
195199

niCHART/plugins/agetrends/agetrends.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import matplotlib.pyplot as plt
99
import numpy as np
1010
import pandas as pd
11+
from niCHART.plugins.loadsave.dataio import DataIO
1112
from niCHART.core.plotcanvas import PlotCanvas
1213
from niCHART.core.gui.SearchableQComboBox import SearchableQComboBox
1314

@@ -33,13 +34,24 @@ def getUI(self):
3334

3435
def SetupConnections(self):
3536
self.datamodel.data_changed.connect(lambda: self.OnDataChanged())
37+
self.datamodel.model_changed.connect(lambda: self.OnModelChanged())
3638
self.ui.comboBoxROI.currentIndexChanged.connect(self.UpdatePlot)
3739
self.ui.comboBoxHue.currentIndexChanged.connect(self.UpdatePlot)
3840

3941
def OnDataChanged(self):
4042
self.PopulateROI()
4143
self.PopulateHue()
4244

45+
def OnModelChanged(self):
46+
self.GetMUSEROIDict()
47+
self.PopulateROI()
48+
49+
def GetMUSEROIDict(self):
50+
dio = DataIO()
51+
#also read MUSE dictionary
52+
MUSEDictNAMEtoID, MUSEDictIDtoNAME, MUSEDictDataFrame = dio.ReadMUSEDictionary()
53+
self.datamodel.SetMUSEDictionaries(MUSEDictNAMEtoID, MUSEDictIDtoNAME,MUSEDictDataFrame)
54+
4355
def PopulateROI(self):
4456
#get data column header names
4557
datakeys = self.datamodel.GetColumnHeaderNames()
@@ -58,7 +70,6 @@ def PopulateROI(self):
5870
if invalid_ROI in roiList:
5971
roiList.remove(invalid_ROI)
6072

61-
6273
_, MUSEDictIDtoNAME = self.datamodel.GetMUSEDictionaries()
6374
roiList = list(set(roiList).intersection(set(datakeys)))
6475
roiList.sort()
@@ -131,16 +142,16 @@ def PlotAgeTrends(self,plotOptions):
131142
self.plotCanvas.axes.clear()
132143

133144
# seaborn plot on axis
134-
a = sns.scatterplot(x='Age', y=currentROI, hue=currentHue,ax=self.plotCanvas.axes, s=5,
135-
data=self.datamodel.GetData(currentROI,currentHue))
136-
self.plotCanvas.axes.yaxis.set_ticks_position('left')
137-
self.plotCanvas.axes.xaxis.set_ticks_position('bottom')
138-
sns.despine(fig=self.plotCanvas.axes.get_figure(), trim=True)
139-
self.plotCanvas.axes.get_figure().set_tight_layout(True)
140-
141-
# Plot normative range if according GAM model is available
142-
if (self.datamodel.harmonization_model is not None) and (currentROI in ['H_' + x for x in self.datamodel.harmonization_model['ROIs']]):
143-
x,y,z = self.datamodel.GetNormativeRange(currentROI[2:])
145+
if self.datamodel.data is not None:
146+
a = sns.scatterplot(x='Age', y=currentROI, hue=currentHue,ax=self.plotCanvas.axes, s=5,
147+
data=self.datamodel.GetData(currentROI,currentHue))
148+
self.plotCanvas.axes.yaxis.set_ticks_position('left')
149+
self.plotCanvas.axes.xaxis.set_ticks_position('bottom')
150+
sns.despine(fig=self.plotCanvas.axes.get_figure(), trim=True)
151+
self.plotCanvas.axes.get_figure().set_tight_layout(True)
152+
153+
if (self.datamodel.harmonization_model is not None) and (currentROI in [x for x in self.datamodel.harmonization_model['ROIs']]):
154+
x,y,z = self.datamodel.GetNormativeRange(currentROI)
144155
#print('Pooled variance: %f' % (z))
145156
# Plot three lines as expected mean and +/- 2 times standard deviation
146157
sns.lineplot(x=x, y=y, ax=self.plotCanvas.axes, linestyle='-', markers=False, color='k')

niCHART/plugins/harmonizationplugin/harmonization.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def DoHarmonization(self):
3535

3636
covars = self.datamodel.data[['SITE','Age','Sex','DLICV_baseline']].reset_index(drop=True).copy()
3737
covars.loc[:,'Sex'] = covars['Sex'].map({'M':1,'F':0})
38-
covars.loc[covars.Age>100, 'Age']=100
38+
covars.loc[covars.Age>102, 'Age']=102
39+
covars.loc[covars.Age<20, 'Age']=20
3940

4041
# Parameter table for plotting
4142
gamma_ROIs = ['gamma_'+ x for x in self.datamodel.harmonization_model['ROIs']]
@@ -73,9 +74,11 @@ def DoHarmonization(self):
7374
model_delta = pd.DataFrame(self.datamodel.harmonization_model['delta_star'],columns=delta_ROIs,index=[x for x in self.datamodel.harmonization_model['SITE_labels']])
7475
parameters = pd.concat([model_gamma,model_delta],axis=1).sort_index()
7576
else:
76-
oos_data = self.datamodel.data[self.datamodel.data['SITE'].isin(sites_to_harmonize)].dropna(subset=covariates)[[x for x in self.datamodel.harmonization_model['ROIs']]].values
77-
oos_covars = self.datamodel.data[self.datamodel.data.SITE.isin(sites_to_harmonize)].dropna(subset=covariates)[covariates]
77+
oos_data = self.datamodel.data[(self.datamodel.data['SITE'].isin(sites_to_harmonize))&(self.datamodel.data['UseForComBatGAMHarmonization'].notnull())].dropna(subset=covariates)[[x for x in self.datamodel.harmonization_model['ROIs']]].values
78+
oos_covars = self.datamodel.data[(self.datamodel.data['SITE'].isin(sites_to_harmonize))&(self.datamodel.data['UseForComBatGAMHarmonization'].notnull())].dropna(subset=covariates)[covariates]
7879
oos_covars.loc[:,'Sex'] = oos_covars['Sex'].map({'M':1,'F':0})
80+
oos_covars.loc[oos_covars.Age>102, 'Age']=102
81+
oos_covars.loc[oos_covars.Age<20, 'Age']=20
7982
self.model, _ = nh.harmonizationLearn(oos_data, oos_covars,
8083
smooth_terms=['Age'],
8184
smooth_term_bounds=(np.floor(np.min(self.datamodel.data.Age)),np.ceil(np.max(self.datamodel.data.Age))),
@@ -122,7 +125,7 @@ def DoHarmonization(self):
122125
MUSEDictDataFrame= self.datamodel.GetMUSEDictDataFrame()
123126
muse_mappings = self.datamodel.GetDerivedMUSEMap()
124127
for ROI in MUSEDictDataFrame[MUSEDictDataFrame['ROI_LEVEL']=='DERIVED']['ROI_INDEX']:
125-
single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float)
128+
single_ROIs = muse_mappings.loc[ROI].replace('NaN',np.nan).dropna().astype(np.float64)
126129
single_ROIs = ['H_MUSE_Volume_%0d' % x for x in single_ROIs]
127130
muse['H_MUSE_Volume_%d' % ROI] = muse[single_ROIs].sum(axis=1,skipna=False)
128131
muse.drop(columns=['H_MUSE_Volume_702'], inplace=True)

niCHART/plugins/harmonizationplugin/harmonizationplugin.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ def getUI(self):
4242

4343
def SetupConnections(self):
4444
self.ui.load_harmonization_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
45-
if self.datamodel.data is None:
46-
self.ui.load_harmonization_model_Btn.setEnabled(False)
47-
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
4845
self.ui.load_other_model_Btn.clicked.connect(lambda: self.OnLoadHarmonizationModelBtnClicked())
4946
self.ui.show_data_Btn.clicked.connect(lambda: self.OnShowDataBtnClicked())
5047
self.ui.apply_model_to_dataset_Btn.clicked.connect(lambda: self.OnApplyModelToDatasetBtnClicked())
@@ -98,13 +95,17 @@ def LoadHarmonizationModel(self, filename):
9895
age_min = self.datamodel.harmonization_model['smooth_model']['bsplines_constructor'].knot_kwds[0]['lower_bound']
9996
model_text4 = ('Valid Age Range: [' + str(age_min) + ', ' + str(age_max) + ']')
10097
model_text1 += '\n'+model_text4
101-
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)
10298
if self.datamodel.data is None:
10399
self.ui.apply_model_to_dataset_Btn.setEnabled(False)
104-
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection or application.\nReturn to Load and Save Data tab to select data.')
100+
model_text5 = 'Data must be loaded before model application.\nReturn to Load and Save Data tab to select data.'
101+
model_text1 += '\n'+model_text5
102+
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)
105103
else:
104+
self.ui.Harmonized_Data_Information_Lbl.setText(model_text1)
106105
self.ui.apply_model_to_dataset_Btn.setEnabled(True)
107106
self.ui.apply_model_to_dataset_Btn.setStyleSheet("background-color: rgb(230,255,230); color: black")
107+
self.datamodel.SetDataFilePath(filename)
108+
self.datamodel.SetHarmonizationModel(self.datamodel.harmonization_model)
108109
self.ui.stackedWidget.setCurrentIndex(0)
109110

110111
def OnLoadHarmonizationModelBtnClicked(self):
@@ -339,9 +340,8 @@ def OnDataChanged(self):
339340
self.ui.show_data_Btn.setEnabled(False)
340341

341342
if self.datamodel.data is None:
342-
self.ui.load_harmonization_model_Btn.setEnabled(False)
343343
self.ui.apply_model_to_dataset_Btn.setEnabled(False)
344-
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model selection.\nReturn to Load and Save Data tab to select data.')
344+
self.ui.Harmonized_Data_Information_Lbl.setText('Data must be loaded before model application.\nReturn to Load and Save Data tab to select data.')
345345
else:
346346
self.ui.load_harmonization_model_Btn.setEnabled(True)
347347
if self.datamodel.harmonization_model is None:

0 commit comments

Comments
 (0)