diff --git a/R/step_epi_slide.R b/R/step_epi_slide.R index 9714971fa..2254e9554 100644 --- a/R/step_epi_slide.R +++ b/R/step_epi_slide.R @@ -44,8 +44,8 @@ step_epi_slide <- function(recipe, ..., .f, - before = 0L, - after = 0L, + before = NULL, + after = NULL, role = "predictor", prefix = "epi_slide_", f_name = clean_f_name(.f), @@ -55,8 +55,12 @@ step_epi_slide <- cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } .f <- validate_slide_fun(.f) - epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type) - epiprocess:::validate_slide_window_arg(after, attributes(recipe$template)$metadata$time_type) + if (!is.null(before)) { + epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type) + } + if (!is.null(after)) { + epiprocess:::validate_slide_window_arg(after, attributes(recipe$template)$metadata$time_type) + } arg_is_chr_scalar(role, prefix, id) arg_is_lgl_scalar(skip) @@ -136,7 +140,6 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) { #' @export bake.step_epi_slide <- function(object, new_data, ...) { - recipes::check_new_data(names(object$columns), object, new_data) col_names <- object$columns name_prefix <- paste0(object$prefix, object$f_name, "_") newnames <- glue::glue("{name_prefix}{col_names}") @@ -153,6 +156,10 @@ bake.step_epi_slide <- function(object, new_data, ...) { class = "epipredict__step__name_collision_error" ) } + # make sure that new_data is actually an epi_df + if (!inherits(new_data, "epi_df")) { + new_data <- new_data %>% as_epi_df() + } # TODO: Uncomment this whenever we make the optimized versions available. # if (any(vapply(c(mean, sum), \(x) identical(x, object$.f), logical(1L)))) { # cli_warn( diff --git a/man/step_epi_slide.Rd b/man/step_epi_slide.Rd index 46bb386ad..d6559e081 100644 --- a/man/step_epi_slide.Rd +++ b/man/step_epi_slide.Rd @@ -8,8 +8,8 @@ step_epi_slide( recipe, ..., .f, - before = 0L, - after = 0L, + before = NULL, + after = NULL, role = "predictor", prefix = "epi_slide_", f_name = clean_f_name(.f), diff --git a/tests/testthat/_snaps/step_epi_slide.md b/tests/testthat/_snaps/step_epi_slide.md new file mode 100644 index 000000000..ea392d0b1 --- /dev/null +++ b/tests/testthat/_snaps/step_epi_slide.md @@ -0,0 +1,25 @@ +# epi_slide works on weekly data with one of before/ahead set + + Code + baked + Output + An `epi_df` object, 40 x 4 with metadata: + * geo_type = state + * time_type = week + * as_of = 1999-09-09 + + # A tibble: 40 x 4 + geo_value time_value value epi_slide__.f_value + * + 1 ca 2022-01-01 2 2 + 2 ca 2022-01-08 3 2.5 + 3 ca 2022-01-15 4 3 + 4 ca 2022-01-22 5 3.5 + 5 ca 2022-01-29 6 4.5 + 6 ca 2022-02-05 7 5.5 + 7 ca 2022-02-12 8 6.5 + 8 ca 2022-02-19 9 7.5 + 9 ca 2022-02-26 10 8.5 + 10 ca 2022-03-05 11 9.5 + # i 30 more rows + diff --git a/tests/testthat/test-step_epi_slide.R b/tests/testthat/test-step_epi_slide.R index 29e046eae..dff7a8a4b 100644 --- a/tests/testthat/test-step_epi_slide.R +++ b/tests/testthat/test-step_epi_slide.R @@ -8,6 +8,14 @@ edf <- data.frame( ) %>% as_epi_df() +tt_week <- seq(as.Date("2022-01-01"), by = "1 week", length.out = 20) +edf_weekly <- data.frame( + time_value = c(tt_week, tt_week), + geo_value = rep(c("ca", "ny"), each = 20L), + value = c(2:21, 3:22) +) %>% + as_epi_df() + r <- epi_recipe(edf) rolled_before <- edf %>% group_by(geo_value) %>% @@ -73,3 +81,14 @@ test_that("epi_slide handles different function specs", { expect_equal(blfun[[4]], rolled_before) expect_equal(nblfun[[4]], rolled_before) }) + +test_that("epi_slide works on weekly data with one of before/ahead set", { + expect_no_error( + baked <- epi_recipe(edf_weekly) %>% + step_epi_slide(value, .f = "mean", before = as.difftime(3, units = "weeks")) %>% + prep(edf_weekly) %>% + bake(new_data = NULL) + ) + attributes(baked)$metadata$as_of <- as.Date("1999-09-09") + expect_snapshot(baked) +})