---
title: "Global and Local Model Interpretation"
author: "[Aaron Coyner](https://github.com/aaroncoyner/)"
date: "`r Sys.Date()`"
output:
html_notebook:
toc: true
toc_depth: 3
toc_float: true
---
```{r setup, include=F}
knitr::opts_chunk$set(echo = TRUE)
```
## Data Preparation
First, we'll load the necessary libraries and the Cardiovascular Disease Risk Dataset.
```{r, message=F}
## Install cvdRiskData using devtools::install_github("laderast/cvdRiskData")
library(cvdRiskData)
library(tidyverse)
library(keras)
library(lime)
library(caret)
library(ROCR)
library(corrr)
library(tidyquant)
## Set seed for repeatability
set.seed(655)
## Load data from cvdRiskData package
data(cvd_patient)
## Remove the patientID, (binned) age (e.g. 10-20, 21-30, etc.), and race columns
cvd_patient <- select(cvd_patient, -patientID, -age, -race)
summary(cvd_patient)
```
We see that we have 9 predictors:
- `htn`: patient has (1) or does not have (0) hypertension
- `treat`: patient is (1) or is not (0) receiving treatment for hypertension
- `smoking`: patient is (1) or is not (0) a smoker
- `t2d`: patient has (1) or does not have (0) type II diabetes
- `gender`: patient is male (1) or female (0)
- `numAge`: patient age in years
- `bmi`: patient's BMI
- `tchol`: patient's total cholesterol value
- `sbp`: patient's systolic blood pressure
### Split the Data
To train and validate our model, we need to split it into separate training and testing datasets. We'll use the `createDataPartition()` function available in the `caret` package to perform an 80/20 split on the dataset.
After splitting, we inspect the distribution of CVD patients in the training dataset and find that the `N` cases (people without CVD) are more highly represented. This imbalance can lead machine learning models to overfit to the over-represented class and underfit to the under-represented class.
```{r}
## Partition dataset into train (80%) and test (20%)
train_idx <- createDataPartition(cvd_patient$cvd, p = 0.80, list = FALSE)
train_data <- cvd_patient[train_idx, ]
test_data <- cvd_patient[-train_idx, ]
## Inspect class distribution
table(train_data$cvd)
```
### Downsample the Data
To mitigate this issue, we will use downsampling via the `downSample()` function available in `caret`. This function randomly removes cases from the over-represented class until the class distributions are equal. Two minor issues with this function are that it (A) creates a new column called `Class`, which contains the same data as the `cvd` column, and (B) sorts the data by class. So, we'll need to remove the `Class` column and reshuffle our dataset.
```{r}
## Downsample the dataset so that classes are balanced
train_data <-
train_data %>%
downSample(train_data$cvd) %>%
select(-Class)
## Because downSample() groups by class, reshuffle the downsampled dataset
shuffle_idx <- sample(nrow(train_data))
train_data <- train_data[shuffle_idx,]
## Inspect class distribution
table(train_data$cvd)
```
### Scale the Data
For many machine learning models to perform adequately, the input data must be scaled into the [-1, 1] or [0, 1] range. We create a simple function, `scale_data`, that takes as input the dataframe we'd like to scale and the column indices of that dataframe that should be kept. It converts that dataframe into a matrix, so that all variables become numeric, applies the function $f(x) = \frac{x - \min(x)}{\max(x) - \min(x)}$ to each column specified by `keep`, and returns the scaled matrix.
NOTE: In practice, we would scale the test dataset by the same values used to scale the training dataset. To keep things simple, we're just going to scale each dataset by its own minimum and maximum values.
```{r}
scale_data <- function(data, keep) {
data[keep] %>%
data.matrix() %>%
apply(2, function(x) (x - min(x)) / (max(x) - min(x))) %>%
return()
}
## Scale numeric data in train and test datasets into the [0, 1] range
x_train <- scale_data(train_data, 1:9)
y_train <- scale_data(train_data, 10)
x_test <- scale_data(test_data, 1:9)
y_test <- scale_data(test_data, 10)
head(x_train)
head(y_train)
```
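As noted above, in practice we would scale the test set with the training set's minimum and maximum values rather than its own. A minimal sketch of that approach is shown below; the helper `scale_by_train` is hypothetical and not part of the workflow used in this notebook.
```{r, eval=FALSE}
## Hypothetical helper: scale test-set columns using the TRAINING set's min/max
scale_by_train <- function(test, train, keep) {
  train_mat <- data.matrix(train[keep])
  test_mat <- data.matrix(test[keep])
  mins <- apply(train_mat, 2, min)
  ranges <- apply(train_mat, 2, max) - mins
  ## Subtract the training minimum and divide by the training range, column by column
  sweep(sweep(test_mat, 2, mins, "-"), 2, ranges, "/")
}
## x_test_alt <- scale_by_train(test_data, train_data, 1:9)
```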
## Model Training
### Setup the Model
Now we can begin setting up our model. The command `Sys.setenv('KMP_DUPLICATE_LIB_OK'='T')` is simply here to prevent an error that can occur for Mac users with certain Intel CPUs. You might not need it, but it also won't hurt.
So let's build our simple feedforward neural network using `keras`, an API for Google's TensorFlow. We tell R that we'd like to build a sequential model (i.e. information is passed from one layer to the next), then we add a hidden layer with 100 units, followed by another hidden layer with 50 units, and an output layer with 1 unit. Because we are performing binary classification, the output node uses a sigmoid activation function. The other layers use a [rectified linear unit](https://www.kaggle.com/dansbecker/rectified-linear-units-relu-in-deep-learning) as their activation function.
You'll also notice that we've implemented [dropout](https://www.kaggle.com/pavansanagapati/dropout-regularization-deep-learning) between each of the layers. This is a regularization technique that reduces overfitting. We've set the `rate` to 0.5 -- for each batch of training data passed through the model, each node has a 50% chance of being randomly ignored. This means that, for that specific training batch, the model must learn to make predictions without those nodes. Essentially, this prevents neighboring nodes from becoming too reliant upon one another, so that the model generalizes better to other datasets.
```{r}
## Avoids errors on certain Intel CPUs on MacOS
Sys.setenv('KMP_DUPLICATE_LIB_OK'='T')
## Build the deep learning model
model <-
keras_model_sequential() %>%
layer_dense(100, activation = 'relu', input_shape = ncol(x_train)) %>%
layer_dropout(0.5) %>%
layer_dense(50, activation = 'relu') %>%
layer_dropout(0.5) %>%
layer_dense(1, activation = 'sigmoid') %>%
compile(
loss = "binary_crossentropy",
optimizer = optimizer_sgd(lr = 1.0, nesterov = T, momentum = 0.9),
metrics = c('acc')
)
```
### Train the Model
On to the fun part...training the model! We create an object, `train`, which fits our model, `model`, using the specified parameters. Hopefully `x` and `y` are self-explanatory. `validation_split` tells `fit()` that we'd like to use the *last* 20% of our training dataset (this is why we had to reshuffle it) as a held-out validation set, so that we can monitor the training process and prevent overfitting.
``` {r results = "hide"}
## Fit the deep learning model to the training dataset
train <-
fit(
model,
x = x_train,
y = y_train,
validation_split = 0.20,
batch_size = 5000,
epochs = 10
)
## Plot the learning metrics
plot(train)
```
We can see that, during training, the training and validation accuracy and loss metrics tracked nicely together. This suggests that the model is not overfit to the dataset. If anything, it's slightly underfit, as suggested by the fact that the validation set accuracy is higher than the training set accuracy and the validation set loss is lower than the training set loss.
## Model Evaluation
### Compute Informative Model Metrics
Although we've shown that our model is not overfit to the training dataset, we still need to demonstrate its generalizability to the held-out test dataset. To do so, we'll first have the model predict the probability that each subject in the test dataset has CVD using `predict_proba()` from `keras`. We'll then use the `prediction()` function from `ROCR` to compare the probabilities to the true test dataset labels. `performance()` uses those predictions to evaluate the true and false positive rates at each probability cutoff/threshold. This produces a receiver operating characteristic (ROC) curve. We'll also evaluate the area under this curve as an overall measure of model performance.
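As a quick aside, `keras` also provides `evaluate()`, which returns the overall test-set loss and accuracy in a single call; a one-line sanity check, separate from the ROC analysis below.
```{r, eval=FALSE}
## Optional sanity check: overall loss and accuracy on the test dataset
evaluate(model, x_test, y_test)
```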
```{r}
## Predict the probability of CVD on the test dataset
nn_predictions <- predict_proba(model, x_test)
## Compare the predictions to the true class label
nn_pr <- prediction(nn_predictions, y_test)
## Calculate the AUC of the model
nn_auc <- performance(nn_pr, measure = "auc")
nn_auc <- nn_auc@y.values[[1]]
## Evaluate the sensitivity and specificity at each threshold.
## Save output into a dataframe.
nn_perf <- performance(nn_pr, measure = "tpr", x.measure = "fpr")
nn_perf_df <- data.frame(nn_perf@x.values, nn_perf@y.values)
colnames(nn_perf_df) <- c("x_values", "y_values")
## Plot the results
ggplot(nn_perf_df, aes(x = x_values, y = y_values)) +
geom_line() +
geom_abline(slope = 1, intercept = 0, color = "red") +
annotate(
geom = "text",
x = 0.75,
y = 0.32,
label = paste('AUC: ', round(nn_auc, 3))
) +
labs(
title = "Receiver Operating Characteristics Curve",
subtitle = "Test Dataset",
x = "1 - Specificity",
y = "Sensitivity"
) +
theme_light()
```
Let's also use the `confusionMatrix()` function to evaluate model performance at a probability threshold of 0.5. Not only does this function create a confusion matrix for us, but it also reports the model accuracy, balanced accuracy, sensitivity, specificity, and a few other metrics.
``` {r}
## Predict the class (rather than probability) of CVD on the test dataset
nn_classes <- predict_classes(model, x_test)
## Create a confusion matrix
nn_conf <- confusionMatrix(as.factor(nn_classes), as.factor(y_test), positive = '1')
nn_conf
```
## Model Interpretability
Great! So far, we've built a fairly decent model for CVD prediction. It's far from perfect, but it's certainly better than chance. But how is the model making its predictions?
```{r}
summary(model)
```
This summary shows that, for our "simple" feedforward neural network, there are 6,101 trainable parameters (i.e. weighted connections). Practically, we cannot infer what the 1,000 connections between the input data and the first layer mean and, unfortunately, there's no way to know what the 5,050 connections between the first and second hidden layers or the 51 connections between the second hidden layer and the output layer actually mean.
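As a quick sanity check on that count, we can recompute it from the layer sizes above: each dense layer contributes `inputs * units` weights plus `units` biases.
```{r}
## Parameter count per dense layer: (inputs * units) + units
(9 * 100 + 100) + (100 * 50 + 50) + (50 * 1 + 1)
```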
Without this information, how are we supposed to use this model practically? Ultimately, our goal is to *prevent* CVD, not just detect it. **We need to learn what risk factors are associated with CVD.**
### Investigate Global Predictors
To begin, we'll first take a look at global predictors of CVD -- we will perform a correlation analysis between `cvd` and the predictors in the training dataset. We'll use the `correlate()` function from the `corrr` package to do this, as it returns a correlation dataframe rather than a matrix. We can use this dataframe to better visualize the findings with a forest plot.
```{r}
## Create correlation table to examine features that correlate GLOBALLY
corrr_analysis <-
as.data.frame(x_train) %>%
mutate(cvd = y_train) %>%
mutate_if(is.factor, as.numeric) %>%
correlate(quiet = T) %>%
focus(cvd) %>%
rename(feature = rowname) %>%
arrange(abs(cvd)) %>%
mutate(feature = as_factor(feature))
corrr_analysis
## Create a forest plot of the global correlation metrics
corrr_analysis %>%
ggplot(aes(x = cvd, y = fct_reorder(feature, desc(cvd)))) +
geom_point() +
geom_segment(
aes(xend = 0, yend = feature),
color = palette_light()[[2]],
data = filter(corrr_analysis, cvd > 0)
) +
geom_point(
color = palette_light()[[2]],
data = filter(corrr_analysis, cvd > 0)
) +
geom_segment(
aes(xend = 0, yend = feature),
color = palette_light()[[1]],
data = filter(corrr_analysis, cvd < 0)
) +
geom_point(
color = palette_light()[[1]],
data = filter(corrr_analysis, cvd < 0)
) +
geom_vline(
xintercept = 0,
color = palette_light()[[5]],
size = 1,
linetype = 2
) +
geom_vline(
xintercept = -0.25,
color = palette_light()[[5]],
size = 1,
linetype = 2
) +
geom_vline(
xintercept = 0.25,
color = palette_light()[[5]],
size = 1,
linetype = 2
) +
theme_tq() +
labs(
title = 'CVD Correlation Analysis',
subtitle = 'Negative vs. Positive Correlations',
x = 'CVD Risk',
y = 'Feature Importance'
)
```
This plot shows that `numAge`, `htn`, `sbp`, and `treat` are correlated with `cvd`. **How are `htn`, `sbp` and `treat` related?** It could be argued that `gender` and `smoking` are loosely correlated with `cvd`, too. But how do we use this information? Do we just assume older people will have CVD? Older people with hypertension? Older people with hypertension who are also receiving treatment for hypertension? This plot is great for identifying the global drivers of CVD, but we need to investigate locally if we want to learn how these predictors interact.
### Investigate Local Predictors
We can use Local Interpretable Model-agnostic Explanations (LIME) to achieve this. Unfortunately, the `lime` package is not set up out of the box to work with `keras`, so we need to write two custom functions for it to work properly:
- `model_type()`: Used to tell lime what type of model we are dealing with. It could be classification, regression, survival, etc.
- `predict_model()`: Used to allow lime to perform predictions that its algorithm can interpret.
```{r}
## Define a new classification model type for LIME
model_type.keras.engine.sequential.Sequential <- function(x, ...) {
return("classification")
}
## Create a prediction wrapper around predict_proba for LIME
predict_model.keras.engine.sequential.Sequential <- function (x, newdata, type, ...) {
pred <- predict_proba(object = x, x = as.matrix(newdata))
return(data.frame(Positive = pred, Negative = 1 - pred))
}
```
Now, we create an explainer object and use it to explain 4 randomly-sampled observations from our test dataset. We can visualize these test cases using the `plot_features()` function.
``` {r}
## Create LIME explainer object
explainer <- lime(as.data.frame(x_train), model, quantile_bins = FALSE, n_bins = 2)
## Randomly sample 4 cases from the test dataset
samples <- sample(1:nrow(x_test), 4)
## Use explainer object to explain subset of data
explanation <-
explain(
as.data.frame(x_test)[samples, ], ## Randomly select 4 sample cases
explainer = explainer,
n_labels = 1, ## Explain a single class (i.e. CVD)
n_features = 3 ## Return the top three features critical to the case
)
## Plot the features that correlate LOCALLY for the single case using a bar chart
plot_features(explanation) +
labs(
title = "LIME: Feature Importance Visualization",
subtitle = "Hold Out (Test) Set, 4 Cases Shown"
)
test_data[samples, ]
```
If we want to evaluate more than just a few observations at a time, we can use the `plot_explanations()` function instead. This creates a heatmap of many cases, rather than individual bar plots for a select few cases.
``` {r}
## Use explainer object to explain subset of data
explanation <-
explain(
as.data.frame(x_test)[sample(1:nrow(x_test), 20), ], ## Randomly select 20 cases
explainer = explainer,
n_labels = 1, ## Explain a single class (i.e. CVD)
n_features = 3 ## Return the top three features critical to EACH case
)
## Plot the features that correlate LOCALLY for ALL cases using a heatmap
plot_explanations(explanation) +
labs(
title = "LIME Feature Importance Heatmap",
subtitle = "Hold Out (Test) Set, First 20 Cases Shown"
)
```