diff --git a/NAMESPACE b/NAMESPACE index 2a4e9893..4c6c744c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -50,6 +50,7 @@ S3method(relative_eff,array) S3method(relative_eff,default) S3method(relative_eff,matrix) S3method(relative_eff,psis) +S3method(subset,psis) S3method(update,psis_loo_ss) S3method(waic,"function") S3method(waic,array) @@ -101,6 +102,7 @@ export(psis_n_eff_values) export(psislw) export(relative_eff) export(stacking_weights) +export(subset.psis) export(waic) export(waic.array) export(waic.function) diff --git a/R/psis.R b/R/psis.R index 48a7cfb7..94e51d37 100644 --- a/R/psis.R +++ b/R/psis.R @@ -83,9 +83,9 @@ #' uw <- weights(psis_result, log=FALSE, normalize = FALSE) # unnormalized weights #' #' -#' psis <- function(log_ratios, ...) UseMethod("psis") + #' @export #' @templateVar fn psis #' @template array @@ -102,6 +102,7 @@ psis.array <- do_psis(log_ratios, r_eff = r_eff, cores = cores) } + #' @export #' @templateVar fn psis #' @template matrix @@ -117,6 +118,7 @@ psis.matrix <- do_psis(log_ratios, r_eff = r_eff, cores = cores) } + #' @export #' @templateVar fn psis #' @template vector @@ -129,12 +131,12 @@ psis.default <- psis.matrix(log_ratios, r_eff = r_eff, cores = 1) } + #' @rdname psis #' @export #' @export weights.psis #' @method weights psis -#' @param object For the `weights()` method, an object returned by `psis()` (a -#' list with class `"psis"`). +#' @param object,x An object returned by `psis()`. #' @param log For the `weights()` method, should the weights be returned on #' the log scale? Defaults to `TRUE`. #' @param normalize For the `weights()` method, should the weights be @@ -165,19 +167,64 @@ weights.psis <- } +# Subset a psis object without breaking it +# +#' @rdname psis +#' @export +#' @export subset.psis +#' @method subset psis +#' @param subset For the `subset()` method, a vector indicating which +#' observations (columns of weights) to keep. 
Can be a logical vector of
+#' length `ncol(x)` (for a psis object `x`) or a shorter integer vector
+#' containing only the indexes to keep.
+#'
+#' @return The `subset()` method returns a `"psis"` object. It is the same as
+#' the input but without the contents corresponding to the unselected indexes.
+#'
+subset.psis <- function(x, subset, ...) {
+  if (anyNA(subset)) {
+    stop("NAs not allowed in subset.", call. = FALSE)
+  }
+  # A numeric 0/1 vector counts as an indicator only when it has one entry per
+  # observation; otherwise (e.g. `subset = 1`) it is a vector of indexes.
+  is_indicator <- length(subset) == dim(x)[2] && all(subset %in% c(0, 1))
+  if (is.logical(subset) || is_indicator) {
+    stopifnot(length(subset) == dim(x)[2])
+    subset <- which(as.logical(subset))
+  } else {
+    # Indexes must be whole numbers within 1..ncol(x): zero or negative values
+    # would leave `dims` and the per-observation attributes out of sync with
+    # the columns actually kept.
+    stopifnot(length(subset) <= dim(x)[2],
+              length(subset) >= 1,
+              all(subset == as.integer(subset)),
+              all(subset >= 1 & subset <= dim(x)[2]))
+    subset <- as.integer(subset)
+  }
+
+  # Drop the unselected observations from every per-observation component.
+  x$log_weights <- x$log_weights[, subset, drop = FALSE]
+  x$diagnostics$pareto_k <- x$diagnostics$pareto_k[subset]
+  x$diagnostics$n_eff <- x$diagnostics$n_eff[subset]
+
+  # Rebuild the attributes so that dim() and the stored per-observation
+  # vectors describe the reduced object; `subset` records the kept indexes.
+  structure(
+    .Data = x,
+    class = class(x),
+    dims = c(dim(x)[1], length(subset)),
+    norm_const_log = attr(x, "norm_const_log")[subset],
+    tail_len = attr(x, "tail_len")[subset],
+    r_eff = attr(x, "r_eff")[subset],
+    subset = subset
+  )
+}
+
+
+#' @rdname psis
 #' @export
 dim.psis <- function(x) {
   attr(x, "dims")
 }
 
+
 #' @rdname psis
 #' @export
-#' @param x For `is.psis()`, an object to check.
 is.psis <- function(x) {
   inherits(x, "psis") && is.list(x)
 }
 
+
 # internal ----------------------------------------------------------------
 
 #' Structure the object returned by the psis methods
diff --git a/man/psis.Rd b/man/psis.Rd
index 57f5b542..48f4cbc8 100644
--- a/man/psis.Rd
+++ b/man/psis.Rd
@@ -6,6 +6,8 @@
 \alias{psis.matrix}
 \alias{psis.default}
 \alias{weights.psis}
+\alias{subset.psis}
+\alias{dim.psis}
 \alias{is.psis}
 \title{Pareto smoothed importance sampling (PSIS)}
 \usage{
@@ -21,6 +23,10 @@ psis(log_ratios, ...)
 
 \method{weights}{psis}(object, ..., log = TRUE, normalize = TRUE)
 
+\method{subset}{psis}(x, subset, ...)
+ +\method{dim}{psis}(x) + is.psis(x) } \arguments{ @@ -57,8 +63,7 @@ the \code{.Rprofile} file to set \code{mc.cores} (using the \code{cores} argumen setting \code{mc.cores} interactively or in a script is fine). }} -\item{object}{For the \code{weights()} method, an object returned by \code{psis()} (a -list with class \code{"psis"}).} +\item{object, x}{An object returned by \code{psis()}.} \item{log}{For the \code{weights()} method, should the weights be returned on the log scale? Defaults to \code{TRUE}.} @@ -66,7 +71,10 @@ the log scale? Defaults to \code{TRUE}.} \item{normalize}{For the \code{weights()} method, should the weights be normalized? Defaults to \code{TRUE}.} -\item{x}{For \code{is.psis()}, an object to check.} +\item{subset}{For the \code{subset()} method, a vector indicating which +observations (columns of weights) to keep. Can be a logical vector of +length \code{ncol(x)} (for a psis object \code{x}) or a shorter integer vector +containing only the indexes to keep.} } \value{ The \code{psis()} methods return an object of class \code{"psis"}, @@ -111,6 +119,9 @@ The \code{weights()} method returns an object with the same dimensions as the \code{log_weights} component of the \code{"psis"} object. The \code{normalize} and \code{log} arguments control whether the returned weights are normalized and whether or not to return them on the log scale. + +The \code{subset()} method returns the input \code{"psis"} object but dropping +the contents corresponding to the unselected indexes. } \description{ Implementation of Pareto smoothed importance sampling (PSIS), a method for @@ -146,7 +157,6 @@ w <- weights(psis_result, log=FALSE) # normalized weights (not log-weights) uw <- weights(psis_result, log=FALSE, normalize = FALSE) # unnormalized weights - } \references{ Vehtari, A., Gelman, A., and Gabry, J. (2017a). 
Practical Bayesian model diff --git a/tests/testthat/test_psis.R b/tests/testthat/test_psis.R index ae847a66..eb396a16 100644 --- a/tests/testthat/test_psis.R +++ b/tests/testthat/test_psis.R @@ -115,6 +115,34 @@ test_that("weights method returns correct output", { }) +test_that("subset method works correctly", { + dims <- dim(psis1) + a1 <- subset(psis1, subset = rep_len(c(TRUE, FALSE), dims[2])) + a2 <- subset(psis1, subset = seq(1, dims[2], by = 2)) + expect_identical(a1, a2) # logical subsetting same as specifying indexes + + a3 <- subset(psis1, subset = c(1, 4, 20)) + expect_equal(a3$log_weights, psis1$log_weights[, c(1, 4, 20), drop=FALSE]) + expect_equal(attr(a3, "tail_len"), attr(psis1, "tail_len")[c(1, 4, 20)]) + expect_equal(attr(a3, "subset"), c(1, 4, 20)) + + expect_error( + subset(psis1, subset = c(TRUE, FALSE)), + "length(subset) == dim(x)[2] is not TRUE", + fixed = TRUE + ) + expect_error( + subset(psis1, subset = seq_len(dim(psis1)[2] + 1)), + "length(subset) <= dim(x)[2] is not TRUE", + fixed = TRUE + ) + expect_error( + subset(psis1, subset = c(1, NA, 3)), + "NAs not allowed in subset" + ) +}) + + test_that("psis_n_eff methods works properly", { w <- weights(psis1, normalize = TRUE, log = FALSE) expect_equal(psis_n_eff.default(w[, 1], r_eff = 1), 1 / sum(w[, 1]^2))