From 2d6cc017d9ec69419d0ce0e565399f01ca7daf85 Mon Sep 17 00:00:00 2001 From: jgabry Date: Thu, 30 May 2019 14:30:53 -0400 Subject: [PATCH 1/3] subset.psis method --- NAMESPACE | 1 + R/psis.R | 48 ++++++++++++++++++++++++++++++++++++++++++++---- man/psis.Rd | 17 ++++++++++++++--- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index e3fd09b8..98c980ae 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -39,6 +39,7 @@ S3method(relative_eff,array) S3method(relative_eff,default) S3method(relative_eff,matrix) S3method(relative_eff,psis) +S3method(subset,psis) S3method(waic,"function") S3method(waic,array) S3method(waic,matrix) diff --git a/R/psis.R b/R/psis.R index 48a7cfb7..dca1925d 100644 --- a/R/psis.R +++ b/R/psis.R @@ -83,9 +83,9 @@ #' uw <- weights(psis_result, log=FALSE, normalize = FALSE) # unnormalized weights #' #' -#' psis <- function(log_ratios, ...) UseMethod("psis") + #' @export #' @templateVar fn psis #' @template array @@ -102,6 +102,7 @@ psis.array <- do_psis(log_ratios, r_eff = r_eff, cores = cores) } + #' @export #' @templateVar fn psis #' @template matrix @@ -117,6 +118,7 @@ psis.matrix <- do_psis(log_ratios, r_eff = r_eff, cores = cores) } + #' @export #' @templateVar fn psis #' @template vector @@ -129,12 +131,12 @@ psis.default <- psis.matrix(log_ratios, r_eff = r_eff, cores = 1) } + #' @rdname psis #' @export #' @export weights.psis #' @method weights psis -#' @param object For the `weights()` method, an object returned by `psis()` (a -#' list with class `"psis"`). +#' @param object,x An object returned by `psis()`. #' @param log For the `weights()` method, should the weights be returned on #' the log scale? Defaults to `TRUE`. #' @param normalize For the `weights()` method, should the weights be @@ -165,19 +167,57 @@ weights.psis <- } +# Subset a psis object without breaking it +# +#' @rdname psis +#' @export +#' @param subset For the `subset()` method, a vector indicating which +#' observations (columns of weights) to keep. Can be a logical vector of +#' length `ncol(x)` (for a psis object `x`) or a shorter integer vector +#' containing only the indexes to keep. +#' +#' @return The `subset()` method returns a valid `"psis"` object +subset.psis <- function(x, subset, ...) { + if (anyNA(subset)) { + stop("NAs not allowed in subset.", call. = FALSE) + } + if (is.logical(subset) || all(subset %in% c(0,1))) { + stopifnot(length(subset) == dim(x)[2]) + subset <- which(as.logical(subset)) + } else { + stopifnot(length(subset) <= dim(x)[2], + all(subset == as.integer(subset))) + subset <- as.integer(subset) + } + + x$log_weights <- x$log_weights[, subset, drop=FALSE] + x$diagnostics$pareto_k <- x$diagnostics$pareto_k[subset] + x$diagnostics$n_eff <- x$diagnostics$n_eff[subset] + + attr(x, "dims") <- c(dim(x)[1], length(subset)) + attr(x, "norm_const_log") <- attr(x, "norm_const_log")[subset] + attr(x, "tail_len") <- attr(x, "tail_len")[subset] + attr(x, "r_eff") <- attr(x, "r_eff")[subset] + + x +} + + +#' @rdname psis #' @export dim.psis <- function(x) { attr(x, "dims") } + #' @rdname psis #' @export -#' @param x For `is.psis()`, an object to check. is.psis <- function(x) { inherits(x, "psis") && is.list(x) } + # internal ---------------------------------------------------------------- #' Structure the object returned by the psis methods diff --git a/man/psis.Rd b/man/psis.Rd index 57f5b542..72263391 100644 --- a/man/psis.Rd +++ b/man/psis.Rd @@ -6,6 +6,8 @@ \alias{psis.matrix} \alias{psis.default} \alias{weights.psis} +\alias{subset.psis} +\alias{dim.psis} \alias{is.psis} \title{Pareto smoothed importance sampling (PSIS)} \usage{ @@ -21,6 +23,10 @@ psis(log_ratios, ...) \method{weights}{psis}(object, ..., log = TRUE, normalize = TRUE) +\method{subset}{psis}(x, subset, ...) + +\method{dim}{psis}(x) + is.psis(x) } \arguments{ @@ -57,8 +63,7 @@ the \code{.Rprofile} file to set \code{mc.cores} (using the \code{cores} argumen setting \code{mc.cores} interactively or in a script is fine). }} -\item{object}{For the \code{weights()} method, an object returned by \code{psis()} (a -list with class \code{"psis"}).} +\item{object, x}{An object returned by \code{psis()}.} \item{log}{For the \code{weights()} method, should the weights be returned on the log scale? Defaults to \code{TRUE}.} @@ -67,6 +72,11 @@ the log scale? Defaults to \code{TRUE}.} normalized? Defaults to \code{TRUE}.} \item{x}{For \code{is.psis()}, an object to check.} + +\item{subset}{For the \code{subset()} method, a vector indicating which +observations (columns of weights) to keep. Can be a logical vector of +length \code{ncol(x)} (for a psis object \code{x}) or a shorter integer vector +containing only the indexes to keep.} } \value{ The \code{psis()} methods return an object of class \code{"psis"}, @@ -111,6 +121,8 @@ The \code{weights()} method returns an object with the same dimensions as the \code{log_weights} component of the \code{"psis"} object. The \code{normalize} and \code{log} arguments control whether the returned weights are normalized and whether or not to return them on the log scale. + +The \code{subset()} method returns a valid \code{"psis"} object } \description{ Implementation of Pareto smoothed importance sampling (PSIS), a method for @@ -146,7 +158,6 @@ w <- weights(psis_result, log=FALSE) # normalized weights (not log-weights) uw <- weights(psis_result, log=FALSE, normalize = FALSE) # unnormalized weights - } \references{ Vehtari, A., Gelman, A., and Gabry, J. (2017a). Practical Bayesian model From b43667f5787a1f3ff7a053c3fd0bdc70d0a8652d Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 3 Jun 2019 19:26:17 -0400 Subject: [PATCH 2/3] test subset.psis --- NAMESPACE | 1 + R/psis.R | 21 ++++++++++++++------- man/psis.Rd | 5 ++--- tests/testthat/test_psis.R | 23 +++++++++++++++++++++++ 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 98c980ae..1026858f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -80,6 +80,7 @@ export(psis_n_eff_values) export(psislw) export(relative_eff) export(stacking_weights) +export(subset.psis) export(waic) export(waic.array) export(waic.function) diff --git a/R/psis.R b/R/psis.R index dca1925d..94e51d37 100644 --- a/R/psis.R +++ b/R/psis.R @@ -171,12 +171,16 @@ weights.psis <- # #' @rdname psis #' @export +#' @export subset.psis +#' @method subset psis #' @param subset For the `subset()` method, a vector indicating which #' observations (columns of weights) to keep. Can be a logical vector of #' length `ncol(x)` (for a psis object `x`) or a shorter integer vector #' containing only the indexes to keep. #' -#' @return The `subset()` method returns a valid `"psis"` object +#' @return The `subset()` returns a `"psis"` object. It is the same as the input +#' but without the contents corresponding to the unselected indexes. +#' subset.psis <- function(x, subset, ...) { if (anyNA(subset)) { stop("NAs not allowed in subset.", call. = FALSE) @@ -194,12 +198,15 @@ subset.psis <- function(x, subset, ...) { x$diagnostics$pareto_k <- x$diagnostics$pareto_k[subset] x$diagnostics$n_eff <- x$diagnostics$n_eff[subset] - attr(x, "dims") <- c(dim(x)[1], length(subset)) - attr(x, "norm_const_log") <- attr(x, "norm_const_log")[subset] - attr(x, "tail_len") <- attr(x, "tail_len")[subset] - attr(x, "r_eff") <- attr(x, "r_eff")[subset] - - x + structure( + .Data = x, + class = class(x), + dims = c(dim(x)[1], length(subset)), + norm_const_log = attr(x, "norm_const_log")[subset], + tail_len = attr(x, "tail_len")[subset], + r_eff = attr(x, "r_eff")[subset], + subset = subset + ) } diff --git a/man/psis.Rd b/man/psis.Rd index 72263391..48f4cbc8 100644 --- a/man/psis.Rd +++ b/man/psis.Rd @@ -71,8 +71,6 @@ the log scale? Defaults to \code{TRUE}.} \item{normalize}{For the \code{weights()} method, should the weights be normalized? Defaults to \code{TRUE}.} -\item{x}{For \code{is.psis()}, an object to check.} - \item{subset}{For the \code{subset()} method, a vector indicating which observations (columns of weights) to keep. Can be a logical vector of length \code{ncol(x)} (for a psis object \code{x}) or a shorter integer vector @@ -122,7 +120,8 @@ the \code{log_weights} component of the \code{"psis"} object. The \code{normaliz \code{log} arguments control whether the returned weights are normalized and whether or not to return them on the log scale. -The \code{subset()} method returns a valid \code{"psis"} object +The \code{subset()} method returns the input \code{"psis"} object but dropping +the contents corresponding to the unselected indexes. } \description{ Implementation of Pareto smoothed importance sampling (PSIS), a method for diff --git a/tests/testthat/test_psis.R b/tests/testthat/test_psis.R index ae847a66..e42c2bdc 100644 --- a/tests/testthat/test_psis.R +++ b/tests/testthat/test_psis.R @@ -115,6 +115,29 @@ test_that("weights method returns correct output", { }) +test_that("subset method works correctly", { + dims <- dim(psis1) + a1 <- subset(psis1, subset = rep_len(c(TRUE, FALSE), dims[2])) + a2 <- subset(psis1, subset = seq(1, dims[2], by = 2)) + expect_identical(a1, a2) # logical subsetting same as specifying indexes + + a3 <- subset(psis1, subset = c(1, 4, 20)) + expect_equal(a3$log_weights, psis1$log_weights[, c(1, 4, 20), drop=FALSE]) + expect_equal(attr(a3, "tail_len"), attr(psis1, "tail_len")[c(1, 4, 20)]) + expect_equal(attr(a3, "subset"), c(1, 4, 20)) + expect_error( + subset(psis1, subset = c(TRUE, FALSE)), + "length(subset) == dim(x)[2] is not TRUE", + fixed = TRUE + ) + expect_error( + subset(psis1, subset = seq_len(dim(psis1)[2] + 1)), + "length(subset) <= dim(x)[2] is not TRUE", + fixed = TRUE + ) +}) + + test_that("psis_n_eff methods works properly", { w <- weights(psis1, normalize = TRUE, log = FALSE) expect_equal(psis_n_eff.default(w[, 1], r_eff = 1), 1 / sum(w[, 1]^2)) From 44a9312d8c26c094830893daad25222204af37eb Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 3 Jun 2019 19:57:01 -0400 Subject: [PATCH 3/3] test for error from NAs --- tests/testthat/test_psis.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/testthat/test_psis.R b/tests/testthat/test_psis.R index e42c2bdc..eb396a16 100644 --- a/tests/testthat/test_psis.R +++ b/tests/testthat/test_psis.R @@ -125,6 +125,7 @@ test_that("subset method works correctly", { expect_equal(a3$log_weights, psis1$log_weights[, c(1, 4, 20), drop=FALSE]) expect_equal(attr(a3, "tail_len"), attr(psis1, "tail_len")[c(1, 4, 20)]) expect_equal(attr(a3, "subset"), c(1, 4, 20)) + expect_error( subset(psis1, subset = c(TRUE, FALSE)), "length(subset) == dim(x)[2] is not TRUE", @@ -135,6 +136,10 @@ test_that("subset method works correctly", { "length(subset) <= dim(x)[2] is not TRUE", fixed = TRUE ) + expect_error( + subset(psis1, subset = c(1, NA, 3)), + "NAs not allowed in subset" + ) })