Skip to content

using check_enough_train_data in practice #452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 28, 2025
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.13
Version: 0.1.14
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
10 changes: 5 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ S3method(apply_frosting,epi_workflow)
S3method(augment,epi_workflow)
S3method(autoplot,canned_epipred)
S3method(autoplot,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,check_enough_data)
S3method(bake,epi_recipe)
S3method(bake,step_adjust_latency)
S3method(bake,step_climate)
Expand Down Expand Up @@ -49,7 +49,7 @@ S3method(key_colnames,recipe)
S3method(mean,quantile_pred)
S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,check_enough_data)
S3method(prep,epi_recipe)
S3method(prep,step_adjust_latency)
S3method(prep,step_climate)
Expand All @@ -65,7 +65,7 @@ S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,check_enough_train_data)
S3method(print,check_enough_data)
S3method(print,climate_fcast)
S3method(print,epi_recipe)
S3method(print,epi_workflow)
Expand Down Expand Up @@ -109,7 +109,7 @@ S3method(slather,layer_threshold)
S3method(slather,layer_unnest)
S3method(snap,default)
S3method(snap,quantile_pred)
S3method(tidy,check_enough_train_data)
S3method(tidy,check_enough_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
Expand Down Expand Up @@ -142,7 +142,7 @@ export(autoplot)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(check_enough_train_data)
export(check_enough_data)
export(clean_f_name)
export(climate_args_list)
export(climatological_forecaster)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Removes dependence on the `distributional` package, replacing the quantiles
with `hardhat::quantile_pred()`. Some associated functions are deprecated with
`lifecycle` messages.
- Rename `check_enough_train_data()` to `check_enough_data()`, and generalize it
  enough to use as a check on either training or testing data.
- Add a check for enough data to predict in `arx_forecaster()`

## Improvements

Expand All @@ -33,6 +36,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Add `climatological_forecaster()` to automatically create climate baselines
- Replace `dist_quantiles()` with `hardhat::quantile_pred()`
- Allow `quantile()` to threshold to an interval if desired (#434)
- `arx_forecaster()` detects if there's enough data to predict

## Bug fixes

Expand Down
2 changes: 1 addition & 1 deletion R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ arx_class_epi_workflow <- function(
step_training_window(n_recent = args_list$n_training)

if (!is.null(args_list$check_enough_data_n)) {
r <- check_enough_train_data(
r <- check_enough_data(
r,
recipes::all_predictors(),
recipes::all_outcomes(),
Expand Down
10 changes: 6 additions & 4 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,14 @@ arx_fcast_epi_workflow <- function(
step_epi_ahead(!!outcome, ahead = args_list$ahead)
r <- r %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
check_enough_data(all_predictors(), min_observations = 1, skip = FALSE)

if (!is.null(args_list$check_enough_data_n)) {
r <- r %>% check_enough_train_data(
r <- r %>% check_enough_data(
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
all_outcomes(),
min_observations = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
Expand Down
2 changes: 1 addition & 1 deletion R/canned-epipred.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ print.canned_epipred <- function(x, name, ...) {
"At forecast date{?s}: {.val {fds}},",
"For target date{?s}: {.val {tds}},"
))
if ("actions" %in% names(x$pre) && "recipe" %in% names(x$pre$actions)) {
if ("pre" %in% names(x) && "actions" %in% names(x$pre) && "recipe" %in% names(x$pre$actions)) {
fit_recipe <- extract_recipe(x$epi_workflow)
if (detect_step(fit_recipe, "adjust_latency")) {
is_adj_latency <- map_lgl(fit_recipe$steps, function(x) inherits(x, "step_adjust_latency"))
Expand Down
193 changes: 193 additions & 0 deletions R/check_enough_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#' Check the dataset contains enough data points.
#'
#' `check_enough_data` creates a *specification* of a recipe
#' operation that will check if variables contain enough data.
#'
#' @param recipe A recipe object. The check will be added to the
#'   sequence of operations for this recipe.
#' @param ... One or more selector functions to choose variables for this check.
#'   See [selections()] for more details. You will usually want to use
#'   [recipes::all_predictors()] and/or [recipes::all_outcomes()] here.
#' @param min_observations The minimum number of data points required for
#'   training. If this is NULL, the total number of predictors will be used.
#' @param epi_keys A character vector of column names on which to group the data
#'   and check threshold within each group. Useful if your forecaster trains
#'   per group (for example, per geo_value).
#' @param drop_na A logical for whether to count NA values as valid rows.
#' @param role Not used by this check since no new variables are
#'   created.
#' @param trained A logical for whether the selectors in `...`
#'   have been resolved by [prep()].
#' @param id A character string that is unique to this check to identify it.
#' @param skip A logical. If `TRUE`, only training data is checked, while if
#'   `FALSE`, both training and predicting data is checked. Technically, this
#'   answers the question "should the check be skipped when the recipe is baked
#'   by [bake()]?" While all operations are baked when [prep()] is run, some
#'   operations may not be able to be conducted on new data (e.g. processing the
#'   outcome variable(s)). Care should be taken when using `skip = TRUE` as it
#'   may affect the computations for subsequent operations.
#' @return An updated version of `recipe` with the new check added to the
#'   sequence of any existing operations.
#' @family checks
#' @export
#' @details This check will break the `prep` and/or bake function if any of the
#'   checked columns do not have enough non-NA values. If the check passes,
#'   nothing is changed in the data. It is best used after every other step.
#'
#'   For checking training data, it is best to set `...` to be
#'   `all_predictors(), all_outcomes()`, while for checking prediction data, it
#'   is best to set `...` to be `all_predictors()` only, with
#'   `min_observations = 1`.
#'
#' # tidy() results
#'
#' When you [`tidy()`][tidy.recipe()] this check, a tibble with column
#' `terms` (the selectors or variables selected) is returned.
#'
check_enough_data <-
  function(recipe,
           ...,
           min_observations = NULL,
           epi_keys = NULL,
           drop_na = TRUE,
           role = NA,
           trained = FALSE,
           skip = TRUE,
           id = rand_id("enough_data")) {
    recipes::add_check(
      recipe,
      check_enough_data_new(
        min_observations = min_observations,
        epi_keys = epi_keys,
        drop_na = drop_na,
        terms = enquos(...),
        role = role,
        trained = trained,
        # Columns are resolved later, at prep() time.
        columns = NULL,
        skip = skip,
        id = id
      )
    )
  }

# Internal constructor for the `check_enough_data` recipe check. It is a thin
# wrapper around recipes::check() so that both check_enough_data() (untrained)
# and prep.check_enough_data() (trained, with resolved columns) can build the
# same object.
check_enough_data_new <- function(min_observations, epi_keys, drop_na, terms,
                                  role, trained, columns, skip, id) {
  recipes::check(
    subclass = "enough_data",
    prefix = "check_",
    min_observations = min_observations,
    epi_keys = epi_keys,
    drop_na = drop_na,
    terms = terms,
    role = role,
    trained = trained,
    columns = columns,
    skip = skip,
    id = id
  )
}

#' @export
#' @export
prep.check_enough_data <- function(x, training, info = NULL, ...) {
  # Resolve the selectors in `...` against the training data.
  col_names <- recipes::recipes_eval_select(x$terms, training, info)
  if (is.null(x$min_observations)) {
    # Default threshold: at least as many observations as selected columns.
    x$min_observations <- length(col_names)
  }

  # Also run the check at prep() time (not only at bake() time) so that a
  # failure during model fitting reports "train" in its error message rather
  # than "predict".
  check_enough_data_core(training, x, col_names, "train")

  # Return a trained copy of the check with the resolved columns stored.
  check_enough_data_new(
    min_observations = x$min_observations,
    epi_keys = x$epi_keys,
    drop_na = x$drop_na,
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    columns = col_names,
    skip = x$skip,
    id = x$id
  )
}

#' @export
#' @export
bake.check_enough_data <- function(object, new_data, ...) {
  # Re-run the check on the data being baked; aborts with a "predict"-flavored
  # message when a checked column falls short. The data itself passes through
  # unchanged.
  check_enough_data_core(new_data, object, object$columns, "predict")
  new_data
}

#' @export
#' @export
print.check_enough_data <- function(x, width = max(20, options()$width - 30), ...) {
  # One-line summary shown when printing a recipe containing this check.
  header <- paste0("Check enough data (n = ", x$min_observations, ") for ")
  recipes::print_step(x$columns, x$terms, x$trained, header, width)
  invisible(x)
}

#' @export
#' @export
tidy.check_enough_data <- function(x, ...) {
  # Before prep() we only know the selectors; afterwards, the resolved columns.
  out <- if (recipes::is_trained(x)) {
    tibble(terms = unname(x$columns))
  } else {
    tibble(terms = recipes::sel2char(x$terms))
  }
  # Attach the check's settings; NULL values (e.g. epi_keys unset) simply
  # add no column.
  out$id <- x$id
  out$min_observations <- x$min_observations
  out$epi_keys <- x$epi_keys
  out$drop_na <- x$drop_na
  out
}

# Shared worker for prep.check_enough_data() / bake.check_enough_data().
#
# Verifies that each checked column (optionally within each `epi_keys` group)
# has at least `min_observations` usable rows, and aborts with a classed error
# naming the offending columns otherwise. Returns nothing useful; it is called
# for its side effect (the possible abort).
#
# @param epi_df The data to check (training data at prep time, new data at
#   bake time).
# @param step_obj The check object; fields used: epi_keys, drop_na,
#   min_observations.
# @param col_names Character vector of resolved column names to check.
# @param train_or_predict Either "train" or "predict"; interpolated into the
#   error message only.
check_enough_data_core <- function(epi_df, step_obj, col_names, train_or_predict) {
  # Group by the epi keys (no-op grouping when epi_keys is NULL).
  epi_df <- epi_df %>%
    group_by(across(all_of(.env$step_obj$epi_keys)))
  if (step_obj$drop_na) {
    # A row only counts if it is complete (non-NA) across ALL checked columns.
    any_missing_data <- epi_df %>%
      mutate(any_are_na = rowSums(across(any_of(.env$col_names), ~ is.na(.x))) > 0) %>%
      # count the number of rows where they're all not na
      summarise(sum(any_are_na == 0) < .env$step_obj$min_observations, .groups = "drop")
    # TRUE if ANY group lacks enough complete rows.
    any_missing_data <- any_missing_data %>%
      summarize(across(all_of(setdiff(names(any_missing_data), step_obj$epi_keys)), any)) %>%
      any()

    # figuring out which individual columns (if any) are to blame for this dearth
    # of data
    cols_not_enough_data <- epi_df %>%
      summarise(
        across(
          all_of(.env$col_names),
          ~ sum(!is.na(.x)) < .env$step_obj$min_observations
        ),
        .groups = "drop"
      ) %>%
      # Aggregate across keys (if present)
      summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
      unlist() %>%
      # Select the names of the columns that are TRUE
      names(.)[.]

    # Either all columns have enough data individually (so the message below is
    # never sent), or no single column is short on its own, which means it's
    # the combination of all of them jointly that lacks complete rows.
    if (length(cols_not_enough_data) == 0) {
      cols_not_enough_data <-
        glue::glue("no single column, but the combination of {paste0(col_names, collapse = ', ')}")
    }
  } else {
    # if we're not dropping na values, just count rows per group; NA rows
    # count as valid, so every checked column shares the same dplyr::n().
    cols_not_enough_data <- epi_df %>%
      summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$step_obj$min_observations))
    # NOTE(review): this branch aggregates with all() (abort only when every
    # group/column is short), whereas the drop_na branch uses any() (abort when
    # any group is short) — confirm the asymmetry is intentional. Also, the
    # summarise above omits `.groups = "drop"`; with 2+ epi_keys the result
    # stays grouped and the key columns may leak into all()/any() below —
    # verify against a multi-key caller.
    any_missing_data <- cols_not_enough_data %>%
      summarize(across(all_of(.env$col_names), all)) %>%
      all()
    cols_not_enough_data <- cols_not_enough_data %>%
      summarise(across(all_of(.env$col_names), any), .groups = "drop") %>%
      unlist() %>%
      # Select the names of the columns that are TRUE
      names(.)[.]
  }

  if (any_missing_data) {
    # Classed abort so callers/tests can catch epipredict__not_enough_data.
    cli_abort(
      "The following columns don't have enough data to {train_or_predict}: {cols_not_enough_data}.",
      class = "epipredict__not_enough_data"
    )
  }
}
Loading