Skip to content

Commit 4acbeb0

Browse files
authored
Merge pull request #370 from cmu-delphi/330-epi_recipe
Deprecate `epi_recipe()` in favour of `recipe()`
2 parents f76961c + 36034fd commit 4acbeb0

File tree

101 files changed

+590
-882
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+590
-882
lines changed

DESCRIPTION

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.20
3+
Version: 0.1.0
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -23,7 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
2323
https://cmu-delphi.github.io/epipredict
2424
BugReports: https://github.com/cmu-delphi/epipredict/issues/
2525
Depends:
26-
epiprocess (>= 0.7.5),
26+
epiprocess (>= 0.7.12),
2727
parsnip (>= 1.0.0),
2828
R (>= 3.5.0)
2929
Imports:

NAMESPACE

+4-5
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ S3method(bake,step_training_window)
2929
S3method(detect_layer,frosting)
3030
S3method(detect_layer,workflow)
3131
S3method(epi_recipe,default)
32-
S3method(epi_recipe,epi_df)
33-
S3method(epi_recipe,formula)
3432
S3method(extract_argument,epi_workflow)
3533
S3method(extract_argument,frosting)
3634
S3method(extract_argument,layer)
@@ -96,6 +94,8 @@ S3method(print,step_naomit)
9694
S3method(print,step_population_scaling)
9795
S3method(print,step_training_window)
9896
S3method(quantile,dist_quantiles)
97+
S3method(recipe,epi_df)
98+
S3method(recipes::recipe,formula)
9999
S3method(refresh_blueprint,default_epi_recipe_blueprint)
100100
S3method(residuals,flatline)
101101
S3method(run_mold,default_epi_recipe_blueprint)
@@ -152,7 +152,6 @@ export(default_epi_recipe_blueprint)
152152
export(detect_layer)
153153
export(dist_quantiles)
154154
export(epi_recipe)
155-
export(epi_recipe_blueprint)
156155
export(epi_workflow)
157156
export(extract_argument)
158157
export(extract_frosting)
@@ -183,13 +182,12 @@ export(layer_residual_quantiles)
183182
export(layer_threshold)
184183
export(layer_unnest)
185184
export(nested_quantiles)
186-
export(new_default_epi_recipe_blueprint)
187-
export(new_epi_recipe_blueprint)
188185
export(pivot_quantiles_longer)
189186
export(pivot_quantiles_wider)
190187
export(prep)
191188
export(quantile_reg)
192189
export(rand_id)
190+
export(recipe)
193191
export(remove_epi_recipe)
194192
export(remove_frosting)
195193
export(remove_model)
@@ -264,6 +262,7 @@ importFrom(magrittr,"%>%")
264262
importFrom(recipes,bake)
265263
importFrom(recipes,prep)
266264
importFrom(recipes,rand_id)
265+
importFrom(recipes,recipe)
267266
importFrom(rlang,"!!!")
268267
importFrom(rlang,"!!")
269268
importFrom(rlang,"%@%")

R/arx_classifier.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#' be real-valued. Conversion of this data to unordered classes is handled
1111
#' internally based on the `breaks` argument to [arx_class_args_list()].
1212
#' If discrete classes are already in the `epi_df`, it is recommended to
13-
#' code up a classifier from scratch using [epi_recipe()].
13+
#' code up a classifier from scratch using [recipe()].
1414
#' @param trainer A `{parsnip}` model describing the type of estimation.
1515
#' For now, we enforce `mode = "classification"`. Typical values are
1616
#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
@@ -129,7 +129,7 @@ arx_class_epi_workflow <- function(
129129

130130
# --- preprocessor
131131
# ------- predictors
132-
r <- epi_recipe(epi_data) %>%
132+
r <- recipe(epi_data) %>%
133133
step_growth_rate(
134134
dplyr::all_of(predictors),
135135
role = "grp",

R/arx_forecaster.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ arx_fcast_epi_workflow <- function(
122122
lags <- arx_lags_validator(predictors, args_list$lags)
123123

124124
# --- preprocessor
125-
r <- epi_recipe(epi_data)
125+
r <- recipe(epi_data)
126126
for (l in seq_along(lags)) {
127127
p <- predictors[l]
128128
r <- step_epi_lag(r, !!p, lag = lags[[l]])

R/autoplot.R

+6-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ggplot2::autoplot
3232
#' jhu <- case_death_rate_subset %>%
3333
#' filter(time_value >= as.Date("2021-11-01"))
3434
#'
35-
#' r <- epi_recipe(jhu) %>%
35+
#' r <- recipe(jhu) %>%
3636
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
3737
#' step_epi_ahead(death_rate, ahead = 7) %>%
3838
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
@@ -56,7 +56,7 @@ ggplot2::autoplot
5656
#' # ------- Show multiple horizons
5757
#'
5858
#' p <- lapply(c(7, 14, 21, 28), function(h) {
59-
#' r <- epi_recipe(jhu) %>%
59+
#' r <- recipe(jhu) %>%
6060
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
6161
#' step_epi_ahead(death_rate, ahead = h) %>%
6262
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
@@ -184,7 +184,10 @@ autoplot.epi_workflow <- function(
184184
}
185185

186186
if (".pred" %in% names(predictions)) {
187-
ntarget_dates <- n_distinct(predictions$time_value)
187+
ntarget_dates <- dplyr::n_distinct(predictions$time_value)
188+
if (distributional::is_distribution(predictions$.pred)) {
189+
predictions <- dplyr::mutate(predictions, .pred = median(.pred))
190+
}
188191
if (ntarget_dates > 1L) {
189192
bp <- bp +
190193
geom_line(

R/blueprint-epi_recipe-default.R

+51-93
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,69 @@
1-
#' Recipe blueprint that accounts for `epi_df` panel data
2-
#'
3-
#' Used for simplicity. See [hardhat::new_recipe_blueprint()] or
4-
#' [hardhat::default_recipe_blueprint()] for more details.
5-
#'
6-
#' @inheritParams hardhat::new_recipe_blueprint
1+
#' Default epi_recipe blueprint
72
#'
8-
#' @details The `bake_dependent_roles` are automatically set to `epi_df` defaults.
9-
#' @return A recipe blueprint.
3+
#' Recipe blueprint that accounts for `epi_df` panel data.
4+
#' Used for simplicity. See [hardhat::default_recipe_blueprint()] for more
5+
#' details. This subclass is nearly the same, except it ensures that
6+
#' downstream processing doesn't drop the epi_df class from the data.
107
#'
11-
#' @keywords internal
8+
#' @inheritParams hardhat::default_recipe_blueprint
9+
#' @return An `epi_recipe` blueprint.
1210
#' @export
13-
new_epi_recipe_blueprint <-
14-
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,
15-
composition = "tibble",
16-
ptypes = NULL, recipe = NULL, ..., subclass = character()) {
17-
hardhat::new_recipe_blueprint(
18-
intercept = intercept,
19-
allow_novel_levels = allow_novel_levels,
20-
fresh = fresh,
21-
composition = composition,
22-
ptypes = ptypes,
23-
recipe = recipe,
24-
...,
25-
subclass = c(subclass, "epi_recipe_blueprint")
26-
)
27-
}
28-
29-
30-
#' @rdname new_epi_recipe_blueprint
31-
#' @export
32-
epi_recipe_blueprint <-
33-
function(intercept = FALSE, allow_novel_levels = FALSE,
34-
fresh = TRUE,
35-
composition = "tibble") {
36-
new_epi_recipe_blueprint(
37-
intercept = intercept,
38-
allow_novel_levels = allow_novel_levels,
39-
fresh = fresh,
40-
composition = composition
41-
)
42-
}
11+
#' @keywords internal
12+
default_epi_recipe_blueprint <- function(intercept = FALSE,
13+
allow_novel_levels = FALSE,
14+
fresh = TRUE,
15+
strings_as_factors = FALSE,
16+
composition = "tibble") {
17+
new_default_epi_recipe_blueprint(
18+
intercept = intercept,
19+
allow_novel_levels = allow_novel_levels,
20+
fresh = fresh,
21+
strings_as_factors = strings_as_factors,
22+
composition = composition
23+
)
24+
}
4325

44-
#' @rdname new_epi_recipe_blueprint
45-
#' @export
46-
default_epi_recipe_blueprint <-
47-
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,
48-
composition = "tibble") {
49-
new_default_epi_recipe_blueprint(
50-
intercept = intercept,
51-
allow_novel_levels = allow_novel_levels,
52-
fresh = fresh,
53-
composition = composition
54-
)
55-
}
26+
new_default_epi_recipe_blueprint <- function(intercept = FALSE,
27+
allow_novel_levels = TRUE,
28+
fresh = TRUE,
29+
strings_as_factors = FALSE,
30+
composition = "tibble",
31+
ptypes = NULL,
32+
recipe = NULL,
33+
extra_role_ptypes = NULL,
34+
...,
35+
subclass = character()) {
36+
hardhat::new_recipe_blueprint(
37+
intercept = intercept,
38+
allow_novel_levels = allow_novel_levels,
39+
fresh = fresh,
40+
strings_as_factors = strings_as_factors,
41+
composition = composition,
42+
ptypes = ptypes,
43+
recipe = recipe,
44+
extra_role_ptypes = extra_role_ptypes,
45+
...,
46+
subclass = c(subclass, "default_epi_recipe_blueprint", "default_recipe_blueprint")
47+
)
48+
}
5649

57-
#' @rdname new_epi_recipe_blueprint
58-
#' @inheritParams hardhat::new_default_recipe_blueprint
59-
#' @export
60-
new_default_epi_recipe_blueprint <-
61-
function(intercept = FALSE, allow_novel_levels = FALSE,
62-
fresh = TRUE,
63-
composition = "tibble", ptypes = NULL, recipe = NULL,
64-
extra_role_ptypes = NULL, ..., subclass = character()) {
65-
new_epi_recipe_blueprint(
66-
intercept = intercept,
67-
allow_novel_levels = allow_novel_levels,
68-
fresh = fresh,
69-
composition = composition,
70-
ptypes = ptypes,
71-
recipe = recipe,
72-
extra_role_ptypes = extra_role_ptypes,
73-
...,
74-
subclass = c(subclass, "default_epi_recipe_blueprint", "default_recipe_blueprint")
75-
)
76-
}
7750

7851
#' @importFrom hardhat run_mold
7952
#' @export
8053
run_mold.default_epi_recipe_blueprint <- function(blueprint, ..., data) {
8154
rlang::check_dots_empty0(...)
82-
# blueprint <- hardhat:::patch_recipe_default_blueprint(blueprint)
83-
cleaned <- mold_epi_recipe_default_clean(blueprint = blueprint, data = data)
84-
blueprint <- cleaned$blueprint
85-
data <- cleaned$data
55+
# we don't do the "cleaning" in `hardhat:::run_mold.default_recipe_blueprint`
56+
# That function drops the epi_df class without any recourse.
57+
# The only way we should be here at all is if `data` is an epi_df, but just
58+
# in case...
59+
if (!is_epi_df(data)) {
60+
cli_warn("`data` is not an {.cls epi_df}. It has class {.cls {class(data)}}.")
61+
}
8662
hardhat:::mold_recipe_default_process(blueprint = blueprint, data = data)
8763
}
8864

89-
mold_epi_recipe_default_clean <- function(blueprint, data) {
90-
hardhat:::check_data_frame_or_matrix(data)
91-
if (!is_epi_df(data)) data <- hardhat:::coerce_to_tibble(data)
92-
hardhat:::new_mold_clean(blueprint, data)
93-
}
94-
9565
#' @importFrom hardhat refresh_blueprint
9666
#' @export
9767
refresh_blueprint.default_epi_recipe_blueprint <- function(blueprint) {
9868
do.call(new_default_epi_recipe_blueprint, as.list(blueprint))
9969
}
100-
101-
102-
## removing this function?
103-
# er_check_is_data_like <- function(.x, .x_nm) {
104-
# if (rlang::is_missing(.x_nm)) {
105-
# .x_nm <- rlang::as_label(rlang::enexpr(.x))
106-
# }
107-
# if (!hardhat:::is_new_data_like(.x)) {
108-
# hardhat:::glubort("`{.x_nm}` must be a data.frame or a matrix, not a {class1(.x)}.")
109-
# }
110-
# .x
111-
# }

R/cdc_baseline_forecaster.R

+21-21
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,25 @@
3636
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
3737
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
3838
#'
39-
#' if (require(ggplot2)) {
40-
#' forecast_date <- unique(preds$forecast_date)
41-
#' four_states <- c("ca", "pa", "wa", "ny")
42-
#' preds %>%
43-
#' filter(geo_value %in% four_states) %>%
44-
#' ggplot(aes(target_date)) +
45-
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
46-
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
47-
#' geom_line(aes(y = .pred), color = "orange") +
48-
#' geom_line(
49-
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
50-
#' aes(x = time_value, y = deaths)
51-
#' ) +
52-
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
53-
#' labs(x = "Date", y = "Weekly deaths") +
54-
#' facet_wrap(~geo_value, scales = "free_y") +
55-
#' theme_bw() +
56-
#' geom_vline(xintercept = forecast_date)
57-
#' }
39+
#' library(ggplot2)
40+
#' forecast_date <- unique(preds$forecast_date)
41+
#' four_states <- c("ca", "pa", "wa", "ny")
42+
#' preds %>%
43+
#' filter(geo_value %in% four_states) %>%
44+
#' ggplot(aes(target_date)) +
45+
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
46+
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
47+
#' geom_line(aes(y = .pred), color = "orange") +
48+
#' geom_line(
49+
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
50+
#' aes(x = time_value, y = deaths)
51+
#' ) +
52+
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
53+
#' labs(x = "Date", y = "Weekly deaths") +
54+
#' facet_wrap(~geo_value, scales = "free_y") +
55+
#' theme_bw() +
56+
#' geom_vline(xintercept = forecast_date)
57+
#'
5858
cdc_baseline_forecaster <- function(
5959
epi_data,
6060
outcome,
@@ -68,7 +68,7 @@ cdc_baseline_forecaster <- function(
6868
outcome <- rlang::sym(outcome)
6969

7070

71-
r <- epi_recipe(epi_data) %>%
71+
r <- recipe(epi_data) %>%
7272
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
7373
recipes::update_role(!!outcome, new_role = "predictor") %>%
7474
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
@@ -79,7 +79,7 @@ cdc_baseline_forecaster <- function(
7979

8080

8181
latest <- get_test_data(
82-
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
82+
recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
8383
forecast_date
8484
)
8585

0 commit comments

Comments
 (0)