diff --git a/NAMESPACE b/NAMESPACE index 8431aee1..d948fca5 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +S3method(calc_impl,"function") S3method(calc_impl,Chisq) S3method(calc_impl,F) S3method(calc_impl,correlation) diff --git a/NEWS.md b/NEWS.md index dfb6a2b9..0655a8cb 100755 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,13 @@ # infer (development version) +* Introduced support for arbitrary test statistics in `calculate()`. In addition + to the pre-implemented `calculate(stat)` options, taken as strings, users can + now supply a function defining any scalar-valued test statistic. See + `?calculate()` to learn more. + * Added missing commas and addressed formatting issues throughout the vignettes and articles. Backticks for package names were removed and missing parentheses for functions were added (@Joscelinrocha). + # infer 1.0.7 * The aliases `p_value()` and `conf_int()`, first deprecated 6 years ago, now diff --git a/R/calculate.R b/R/calculate.R index 9bcff31a..2b9aa8af 100755 --- a/R/calculate.R +++ b/R/calculate.R @@ -12,13 +12,15 @@ #' #' @param x The output from [generate()] for computation-based inference or the #' output from [hypothesize()] piped in to here for theory-based inference. -#' @param stat A string giving the type of the statistic to calculate. Current +#' @param stat A string giving the type of the statistic to calculate or a +#' function that takes in a replicate of `x` and returns a scalar value. Current #' options include `"mean"`, `"median"`, `"sum"`, `"sd"`, `"prop"`, `"count"`, #' `"diff in means"`, `"diff in medians"`, `"diff in props"`, `"Chisq"` (or #' `"chisq"`), `"F"` (or `"f"`), `"t"`, `"z"`, `"ratio of props"`, `"slope"`, #' `"odds ratio"`, `"ratio of means"`, and `"correlation"`. `infer` only #' supports theoretical tests on one or two means via the `"t"` distribution -#' and one or two proportions via the `"z"`. +#' and one or two proportions via the `"z"`. 
See the "Arbitrary test statistics" +#' section below for more on how to define a custom statistic. #' @param order A string vector of specifying the order in which the levels of #' the explanatory variable should be ordered for subtraction (or division #' for ratio-based statistics), where `order = c("first", "second")` means @@ -31,6 +33,38 @@ #' #' @return A tibble containing a `stat` column of calculated statistics. #' +#' @section Arbitrary test statistics: +#' +#' In addition to the pre-implemented statistics documented in `stat`, users can +#' supply an arbitrary test statistic by supplying a function to the `stat` +#' argument. +#' +#' The function should have arguments `stat(x, order, ...)`, where `x` is one +#' replicate's worth of `x`. The `order` argument and ellipses will be supplied +#' directly to the `stat` function. Internally, `calculate()` will split `x` up +#' into data frames by replicate and pass them one-by-one to the supplied `stat`. +#' For example, to implement `stat = "mean"` as a function, one could write: +#' +#' ```r +#' stat_mean <- function(x, order, ...) {mean(x$hours)} +#' obs_mean <- +#' gss %>% +#' specify(response = hours) %>% +#' calculate(stat = stat_mean) +#' +#' set.seed(1) +#' null_dist_mean <- +#' gss %>% +#' specify(response = hours) %>% +#' hypothesize(null = "point", mu = 40) %>% +#' generate(reps = 5, type = "bootstrap") %>% +#' calculate(stat = stat_mean) +#' ``` +#' +#' Note that the same `stat_mean` function is supplied to both `generate()`d and +#' non-`generate()`d infer objects--no need to implement support for grouping +#' by `replicate` yourself. +#' #' @section Missing levels in small samples: #' In some cases, when bootstrapping with small samples, some generated #' bootstrap samples will have only one level of the explanatory variable @@ -97,22 +131,23 @@ calculate <- function(x, ...) 
{ check_type(x, tibble::is_tibble) check_if_mlr(x, "calculate") - stat <- check_calculate_stat(stat) - check_input_vs_stat(x, stat) - check_point_params(x, stat) + stat_chr <- stat_chr(stat) + stat_chr <- check_calculate_stat(stat_chr) + check_input_vs_stat(x, stat_chr) + check_point_params(x, stat_chr) - order <- check_order(x, order, in_calculate = TRUE, stat) + order <- check_order(x, order, in_calculate = TRUE, stat_chr) if (!is_generated(x)) { x$replicate <- 1L } - x <- message_on_excessive_null(x, stat = stat, fn = "calculate") - x <- warn_on_insufficient_null(x, stat, ...) + x <- message_on_excessive_null(x, stat = stat_chr, fn = "calculate") + x <- warn_on_insufficient_null(x, stat_chr, ...) # Use S3 method to match correct calculation result <- calc_impl( - structure(stat, class = gsub(" ", "_", stat)), x, order, ... + structure(stat, class = gsub(" ", "_", stat_chr)), x, order, ... ) result <- copy_attrs(to = result, from = x) @@ -144,9 +179,19 @@ check_if_mlr <- function(x, fn, call = caller_env()) { } } -check_calculate_stat <- function(stat, call = caller_env()) { +stat_chr <- function(stat) { + if (rlang::is_function(stat)) { + return("function") + } + stat +} + +check_calculate_stat <- function(stat, call = caller_env()) { check_type(stat, rlang::is_string, call = call) + if (identical(stat, "function")) { + return(stat) + } # Check for possible `stat` aliases alias_match_id <- match(stat, implemented_stats_aliases[["alias"]]) @@ -178,6 +223,10 @@ check_input_vs_stat <- function(x, stat, call = caller_env()) { ) } + if (identical(stat, "function")) { + return(x) + } + if (!stat %in% possible_stats) { if (has_explanatory(x)) { msg_tail <- glue( @@ -252,7 +301,7 @@ message_on_excessive_null <- function(x, stat = "mean", fn) { warn_on_insufficient_null <- function(x, stat, ...) 
{ if (!is_hypothesized(x) && !has_explanatory(x) && - !stat %in% untheorized_stats && + !stat %in% c(untheorized_stats, "function") && !(stat == "t" && "mu" %in% names(list(...)))) { attr(x, "null") <- "point" attr(x, "params") <- assume_null(x, stat) @@ -626,3 +675,38 @@ calc_impl.z <- function(type, x, order, ...) { df_out } } + +#' @export +calc_impl.function <- function(type, x, order, ..., call = rlang::caller_env()) { + rlang::try_fetch( + { + if (!identical(dplyr::group_vars(x), "replicate")) { + x <- dplyr::group_by(x, replicate) + } + x_by_replicate <- dplyr::group_split(x) + res <- purrr::map(x_by_replicate, ~type(.x, order, ...)) + }, + error = function(cnd) {rethrow_stat_cnd(cnd, call = call)}, + warning = function(cnd) {rethrow_stat_cnd(cnd, call = call)} + ) + + if (!rlang::is_scalar_atomic(res[[1]])) { + cli::cli_abort( + c( + "The supplied {.arg stat} function must return a scalar value.", + "i" = "It returned {.obj_type_friendly {res[[1]]}}." + ), + call = call + ) + } + + tibble::new_tibble(list(stat = unlist(res))) +} + +rethrow_stat_cnd <- function(cnd, call = rlang::caller_env()) { + cli::cli_abort( + "The supplied {.arg stat} function encountered an issue.", + parent = cnd, + call = call + ) +} diff --git a/R/observe.R b/R/observe.R index cc6e7ddc..d437e5c1 100644 --- a/R/observe.R +++ b/R/observe.R @@ -15,6 +15,8 @@ #' #' @return A 1-column tibble containing the calculated statistic `stat`. 
#' +#' @inheritSection calculate Arbitrary test statistics +#' #' @examples #' # calculating the observed mean number of hours worked per week #' gss %>% diff --git a/man/calculate.Rd b/man/calculate.Rd index f2739ca9..e2a595f6 100755 --- a/man/calculate.Rd +++ b/man/calculate.Rd @@ -17,13 +17,15 @@ calculate( \item{x}{The output from \code{\link[=generate]{generate()}} for computation-based inference or the output from \code{\link[=hypothesize]{hypothesize()}} piped in to here for theory-based inference.} -\item{stat}{A string giving the type of the statistic to calculate. Current +\item{stat}{A string giving the type of the statistic to calculate or a +function that takes in a replicate of \code{x} and returns a scalar value. Current options include \code{"mean"}, \code{"median"}, \code{"sum"}, \code{"sd"}, \code{"prop"}, \code{"count"}, \code{"diff in means"}, \code{"diff in medians"}, \code{"diff in props"}, \code{"Chisq"} (or \code{"chisq"}), \code{"F"} (or \code{"f"}), \code{"t"}, \code{"z"}, \code{"ratio of props"}, \code{"slope"}, \code{"odds ratio"}, \code{"ratio of means"}, and \code{"correlation"}. \code{infer} only supports theoretical tests on one or two means via the \code{"t"} distribution -and one or two proportions via the \code{"z"}.} +and one or two proportions via the \code{"z"}. See the "Arbitrary test statistics" +section below for more on how to define a custom statistic.} \item{order}{A string vector of specifying the order in which the levels of the explanatory variable should be ordered for subtraction (or division @@ -48,6 +50,39 @@ supplied \code{stat} for each \code{replicate}. Learn more in \code{vignette("infer")}. } +\section{Arbitrary test statistics}{ + + +In addition to the pre-implemented statistics documented in \code{stat}, users can +supply an arbitrary test statistic by supplying a function to the \code{stat} +argument. 
+ +The function should have arguments \code{stat(x, order, ...)}, where \code{x} is one +replicate's worth of \code{x}. The \code{order} argument and ellipses will be supplied +directly to the \code{stat} function. Internally, \code{calculate()} will split \code{x} up +into data frames by replicate and pass them one-by-one to the supplied \code{stat}. +For example, to implement \code{stat = "mean"} as a function, one could write: + +\if{html}{\out{