Skip to content

Commit f52f11e

Browse files
authored
[R] Allow passing data.frame to SHAP (dmlc#10744)
1 parent ec8cfb3 commit f52f11e

File tree

5 files changed: 101 additions and 7 deletions.

R-package/R/xgb.ggplot.R

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,27 @@ xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med
102102
#' @export
103103
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
104104
trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
105+
if (inherits(data, "xgb.DMatrix")) {
106+
stop(
107+
"'xgb.ggplot.shap.summary' is not compatible with 'xgb.DMatrix' objects. Try passing a matrix or data.frame."
108+
)
109+
}
110+
cols_categ <- NULL
111+
if (!is.null(model)) {
112+
ftypes <- getinfo(model, "feature_type")
113+
if (NROW(ftypes)) {
114+
if (length(ftypes) != ncol(data)) {
115+
stop(sprintf("'data' has incorrect number of columns (expected: %d, got: %d).", length(ftypes), ncol(data)))
116+
}
117+
cols_categ <- colnames(data)[ftypes == "c"]
118+
}
119+
} else if (inherits(data, "data.frame")) {
120+
cols_categ <- names(data)[sapply(data, function(x) is.factor(x) || is.character(x))]
121+
}
122+
if (NROW(cols_categ)) {
123+
warning("Categorical features are ignored in 'xgb.ggplot.shap.summary'.")
124+
}
125+
105126
data_list <- xgb.shap.data(
106127
data = data,
107128
shap_contrib = shap_contrib,
@@ -114,6 +135,10 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
114135
subsample = subsample,
115136
max_observations = 10000 # 10,000 samples per feature.
116137
)
138+
if (NROW(cols_categ)) {
139+
data_list <- lapply(data_list, function(x) x[, !(colnames(x) %in% cols_categ), drop = FALSE])
140+
}
141+
117142
p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
118143
# Reverse factor levels so that the first level is at the top of the plot
119144
p_data[, "feature" := factor(feature, rev(levels(feature)))]
@@ -134,7 +159,8 @@ xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL,
134159
#' @param data_list The result of `xgb.shap.data()`.
135160
#' @param normalize Whether to standardize feature values to mean 0 and
136161
#' standard deviation 1. This is useful for comparing multiple features on the same
137-
#' plot. Default is `FALSE`.
162+
#' plot. Default is `FALSE`. Note that it cannot be used when the data contains
163+
#' categorical features.
138164
#' @return A `data.table` containing the observation ID, the feature name, the
139165
#' feature value (normalized if specified), and the SHAP contribution value.
140166
#' @noRd

R-package/R/xgb.plot.shap.R

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#'
33
#' Visualizes SHAP values against feature values to gain an impression of feature effects.
44
#'
5-
#' @param data The data to explain as a `matrix` or `dgCMatrix`.
5+
#' @param data The data to explain as a `matrix`, `dgCMatrix`, or `data.frame`.
66
#' @param shap_contrib Matrix of SHAP contributions of `data`.
77
#' The default (`NULL`) computes it from `model` and `data`.
88
#' @param features Vector of column indices or feature names to plot. When `NULL`
@@ -285,8 +285,11 @@ xgb.plot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, to
285285
xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
286286
trees = NULL, target_class = NULL, approxcontrib = FALSE,
287287
subsample = NULL, max_observations = 100000) {
288-
if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
289-
stop("data: must be either matrix or dgCMatrix")
288+
if (!inherits(data, c("matrix", "dsparseMatrix", "data.frame")))
289+
stop("data: must be matrix, sparse matrix, or data.frame.")
290+
if (inherits(data, "data.frame") && length(class(data)) > 1L) {
291+
data <- as.data.frame(data)
292+
}
290293

291294
if (is.null(shap_contrib) && (is.null(model) || !inherits(model, "xgb.Booster")))
292295
stop("when shap_contrib is not provided, one must provide an xgb.Booster model")
@@ -311,7 +314,14 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
311314
stop("if model has no feature_names, columns in `data` must match features in model")
312315

313316
if (!is.null(subsample)) {
314-
idx <- sample(x = seq_len(nrow(data)), size = as.integer(subsample * nrow(data)), replace = FALSE)
317+
if (subsample <= 0 || subsample >= 1) {
318+
stop("'subsample' must be a number between zero and one (non-inclusive).")
319+
}
320+
sample_size <- as.integer(subsample * nrow(data))
321+
if (sample_size < 2) {
322+
stop("Sampling fraction involves less than 2 rows.")
323+
}
324+
idx <- sample(x = seq_len(nrow(data)), size = sample_size, replace = FALSE)
315325
} else {
316326
idx <- seq_len(min(nrow(data), max_observations))
317327
}

R-package/man/xgb.plot.shap.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/man/xgb.plot.shap.summary.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/tests/testthat/test_helpers.R

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,26 @@ test_that("xgb.shap.data works with subsampling", {
449449
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
450450
})
451451

452+
test_that("xgb.shap.data works with data frames", {
453+
data(mtcars)
454+
df <- mtcars
455+
df$cyl <- factor(df$cyl)
456+
x <- df[, -1]
457+
y <- df$mpg
458+
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
459+
model <- xgb.train(
460+
data = dm,
461+
params = list(
462+
max_depth = 2,
463+
nthread = 1
464+
),
465+
nrounds = 2
466+
)
467+
data_list <- xgb.shap.data(data = df[, -1], model = model, top_n = 2, subsample = 0.8)
468+
expect_equal(NROW(data_list$data), as.integer(0.8 * nrow(df)))
469+
expect_equal(NROW(data_list$data), NROW(data_list$shap_contrib))
470+
})
471+
452472
test_that("prepare.ggplot.shap.data works", {
453473
.skip_if_vcd_not_available()
454474
data_list <- xgb.shap.data(data = sparse_matrix, model = bst.Tree, top_n = 2)
@@ -472,6 +492,44 @@ test_that("xgb.plot.shap.summary works", {
472492
expect_silent(xgb.ggplot.shap.summary(data = sparse_matrix, model = bst.Tree, top_n = 2))
473493
})
474494

495+
test_that("xgb.plot.shap.summary ignores categorical features", {
496+
.skip_if_vcd_not_available()
497+
data(mtcars)
498+
df <- mtcars
499+
df$cyl <- factor(df$cyl)
500+
levels(df$cyl) <- c("a", "b", "c")
501+
x <- df[, -1]
502+
y <- df$mpg
503+
dm <- xgb.DMatrix(x, label = y, nthread = 1L)
504+
model <- xgb.train(
505+
data = dm,
506+
params = list(
507+
max_depth = 2,
508+
nthread = 1
509+
),
510+
nrounds = 2
511+
)
512+
expect_warning({
513+
xgb.ggplot.shap.summary(data = x, model = model, top_n = 2)
514+
})
515+
516+
x_num <- mtcars[, -1]
517+
x_num$gear <- as.numeric(x_num$gear) - 1
518+
x_num <- as.matrix(x_num)
519+
dm <- xgb.DMatrix(x_num, label = y, feature_types = c(rep("q", 8), "c", "q"), nthread = 1L)
520+
model <- xgb.train(
521+
data = dm,
522+
params = list(
523+
max_depth = 2,
524+
nthread = 1
525+
),
526+
nrounds = 2
527+
)
528+
expect_warning({
529+
xgb.ggplot.shap.summary(data = x_num, model = model, top_n = 2)
530+
})
531+
})
532+
475533
test_that("check.deprecation works", {
476534
ttt <- function(a = NNULL, DUMMY = NULL, ...) {
477535
check.deprecation(...)

0 commit comments

Comments
 (0)