Skip to content

Commit b4d4071

Browse files
authored
Merge pull request #373 from cmu-delphi/352-remove-all-instances-of-epi_keys
352 remove all instances of epi keys
2 parents c77ea78 + 6d8edc0 commit b4d4071

File tree

97 files changed

+853
-1095
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+853
-1095
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.0.19
3+
Version: 0.0.20
44
Authors@R: c(
55
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@ S3method(bake,step_population_scaling)
2828
S3method(bake,step_training_window)
2929
S3method(detect_layer,frosting)
3030
S3method(detect_layer,workflow)
31-
S3method(epi_keys,data.frame)
32-
S3method(epi_keys,default)
33-
S3method(epi_keys,epi_df)
34-
S3method(epi_keys,epi_workflow)
35-
S3method(epi_keys,recipe)
3631
S3method(epi_recipe,default)
3732
S3method(epi_recipe,epi_df)
3833
S3method(epi_recipe,formula)
@@ -55,6 +50,8 @@ S3method(forecast,epi_workflow)
5550
S3method(format,dist_quantiles)
5651
S3method(is.na,dist_quantiles)
5752
S3method(is.na,distribution)
53+
S3method(key_colnames,epi_workflow)
54+
S3method(key_colnames,recipe)
5855
S3method(mean,dist_quantiles)
5956
S3method(median,dist_quantiles)
6057
S3method(predict,epi_workflow)
@@ -154,7 +151,6 @@ export(clean_f_name)
154151
export(default_epi_recipe_blueprint)
155152
export(detect_layer)
156153
export(dist_quantiles)
157-
export(epi_keys)
158154
export(epi_recipe)
159155
export(epi_recipe_blueprint)
160156
export(epi_workflow)
@@ -170,7 +166,6 @@ export(flusight_hub_formatter)
170166
export(forecast)
171167
export(frosting)
172168
export(get_test_data)
173-
export(grab_names)
174169
export(is_epi_recipe)
175170
export(is_epi_workflow)
176171
export(is_layer)
@@ -194,6 +189,7 @@ export(pivot_quantiles_longer)
194189
export(pivot_quantiles_wider)
195190
export(prep)
196191
export(quantile_reg)
192+
export(rand_id)
197193
export(remove_epi_recipe)
198194
export(remove_frosting)
199195
export(remove_model)
@@ -207,6 +203,8 @@ export(step_growth_rate)
207203
export(step_lag_difference)
208204
export(step_population_scaling)
209205
export(step_training_window)
206+
export(tibble)
207+
export(tidy)
210208
export(update_epi_recipe)
211209
export(update_frosting)
212210
export(update_model)
@@ -229,30 +227,50 @@ importFrom(checkmate,assert_number)
229227
importFrom(checkmate,assert_numeric)
230228
importFrom(checkmate,assert_scalar)
231229
importFrom(cli,cli_abort)
230+
importFrom(cli,cli_warn)
232231
importFrom(dplyr,across)
233232
importFrom(dplyr,all_of)
233+
importFrom(dplyr,any_of)
234+
importFrom(dplyr,arrange)
234235
importFrom(dplyr,bind_cols)
236+
importFrom(dplyr,bind_rows)
237+
importFrom(dplyr,everything)
238+
importFrom(dplyr,filter)
239+
importFrom(dplyr,full_join)
235240
importFrom(dplyr,group_by)
236-
importFrom(dplyr,n)
241+
importFrom(dplyr,left_join)
242+
importFrom(dplyr,mutate)
243+
importFrom(dplyr,relocate)
244+
importFrom(dplyr,rename)
245+
importFrom(dplyr,select)
237246
importFrom(dplyr,summarise)
247+
importFrom(dplyr,summarize)
238248
importFrom(dplyr,ungroup)
239249
importFrom(epiprocess,epi_slide)
240250
importFrom(epiprocess,growth_rate)
241251
importFrom(generics,augment)
242252
importFrom(generics,fit)
243253
importFrom(generics,forecast)
254+
importFrom(generics,tidy)
255+
importFrom(ggplot2,aes)
244256
importFrom(ggplot2,autoplot)
257+
importFrom(ggplot2,geom_line)
258+
importFrom(ggplot2,geom_linerange)
259+
importFrom(ggplot2,geom_point)
260+
importFrom(ggplot2,geom_ribbon)
245261
importFrom(hardhat,refresh_blueprint)
246262
importFrom(hardhat,run_mold)
247263
importFrom(magrittr,"%>%")
248264
importFrom(recipes,bake)
249265
importFrom(recipes,prep)
266+
importFrom(recipes,rand_id)
250267
importFrom(rlang,"!!!")
251268
importFrom(rlang,"!!")
252269
importFrom(rlang,"%@%")
253270
importFrom(rlang,"%||%")
254271
importFrom(rlang,":=")
255272
importFrom(rlang,abort)
273+
importFrom(rlang,arg_match)
256274
importFrom(rlang,as_function)
257275
importFrom(rlang,caller_env)
258276
importFrom(rlang,enquo)
@@ -264,6 +282,7 @@ importFrom(rlang,is_logical)
264282
importFrom(rlang,is_null)
265283
importFrom(rlang,is_true)
266284
importFrom(rlang,set_names)
285+
importFrom(rlang,sym)
267286
importFrom(stats,as.formula)
268287
importFrom(stats,family)
269288
importFrom(stats,lm)
@@ -274,9 +293,9 @@ importFrom(stats,predict)
274293
importFrom(stats,qnorm)
275294
importFrom(stats,quantile)
276295
importFrom(stats,residuals)
296+
importFrom(tibble,as_tibble)
277297
importFrom(tibble,tibble)
278298
importFrom(tidyr,crossing)
279-
importFrom(tidyr,drop_na)
280299
importFrom(vctrs,as_list_of)
281300
importFrom(vctrs,field)
282301
importFrom(vctrs,new_rcrd)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
5656
- add functionality to calculate weighted interval scores for `dist_quantiles()`
5757
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
5858
- Add quantile random forests (via `{grf}`) as a parsnip engine
59+
- Replace `epi_keys()` with `epiprocess::key_colnames()`, #352

R/arx_classifier.R

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626
#' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
2727
#'
2828
#' @examples
29+
#' library(dplyr)
2930
#' jhu <- case_death_rate_subset %>%
30-
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
31+
#' filter(time_value >= as.Date("2021-11-01"))
3132
#'
3233
#' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate"))
3334
#'
@@ -45,23 +46,23 @@ arx_classifier <- function(
4546
epi_data,
4647
outcome,
4748
predictors,
48-
trainer = parsnip::logistic_reg(),
49+
trainer = logistic_reg(),
4950
args_list = arx_class_args_list()) {
5051
if (!is_classification(trainer)) {
51-
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
52+
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
5253
}
5354

5455
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
55-
wf <- generics::fit(wf, epi_data)
56+
wf <- fit(wf, epi_data)
5657

5758
preds <- forecast(
5859
wf,
5960
fill_locf = TRUE,
6061
n_recent = args_list$nafill_buffer,
6162
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
6263
) %>%
63-
tibble::as_tibble() %>%
64-
dplyr::select(-time_value)
64+
as_tibble() %>%
65+
select(-time_value)
6566

6667
structure(
6768
list(
@@ -95,17 +96,17 @@ arx_classifier <- function(
9596
#' @export
9697
#' @seealso [arx_classifier()]
9798
#' @examples
98-
#'
99+
#' library(dplyr)
99100
#' jhu <- case_death_rate_subset %>%
100-
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
101+
#' filter(time_value >= as.Date("2021-11-01"))
101102
#'
102103
#' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
103104
#'
104105
#' arx_class_epi_workflow(
105106
#' jhu,
106107
#' "death_rate",
107108
#' c("case_rate", "death_rate"),
108-
#' trainer = parsnip::multinom_reg(),
109+
#' trainer = multinom_reg(),
109110
#' args_list = arx_class_args_list(
110111
#' breaks = c(-.05, .1), ahead = 14,
111112
#' horizon = 14, method = "linear_reg"
@@ -119,18 +120,18 @@ arx_class_epi_workflow <- function(
119120
args_list = arx_class_args_list()) {
120121
validate_forecaster_inputs(epi_data, outcome, predictors)
121122
if (!inherits(args_list, c("arx_class", "alist"))) {
122-
rlang::abort("args_list was not created using `arx_class_args_list().")
123+
cli_abort("`args_list` was not created using `arx_class_args_list()`.")
123124
}
124125
if (!(is.null(trainer) || is_classification(trainer))) {
125-
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
126+
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
126127
}
127128
lags <- arx_lags_validator(predictors, args_list$lags)
128129

129130
# --- preprocessor
130131
# ------- predictors
131132
r <- epi_recipe(epi_data) %>%
132133
step_growth_rate(
133-
tidyselect::all_of(predictors),
134+
dplyr::all_of(predictors),
134135
role = "grp",
135136
horizon = args_list$horizon,
136137
method = args_list$method,
@@ -173,26 +174,24 @@ arx_class_epi_workflow <- function(
173174
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
174175
r <- r %>%
175176
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
176-
step_mutate(
177+
recipes::step_mutate(
177178
outcome_class = cut(!!o2, breaks = args_list$breaks),
178179
role = "outcome"
179180
) %>%
180181
step_epi_naomit() %>%
181-
step_training_window(n_recent = args_list$n_training) %>%
182-
{
183-
if (!is.null(args_list$check_enough_data_n)) {
184-
check_enough_train_data(
185-
.,
186-
all_predictors(),
187-
!!outcome,
188-
n = args_list$check_enough_data_n,
189-
epi_keys = args_list$check_enough_data_epi_keys,
190-
drop_na = FALSE
191-
)
192-
} else {
193-
.
194-
}
195-
}
182+
step_training_window(n_recent = args_list$n_training)
183+
184+
if (!is.null(args_list$check_enough_data_n)) {
185+
r <- check_enough_train_data(
186+
r,
187+
recipes::all_predictors(),
188+
recipes::all_outcomes(),
189+
n = args_list$check_enough_data_n,
190+
epi_keys = args_list$check_enough_data_epi_keys,
191+
drop_na = FALSE
192+
)
193+
}
194+
196195

197196
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
198197
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
@@ -264,7 +263,7 @@ arx_class_args_list <- function(
264263
outcome_transform = c("growth_rate", "lag_difference"),
265264
breaks = 0.25,
266265
horizon = 7L,
267-
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
266+
method = c("rel_change", "linear_reg"),
268267
log_scale = FALSE,
269268
additional_gr_args = list(),
270269
nafill_buffer = Inf,
@@ -274,8 +273,8 @@ arx_class_args_list <- function(
274273
rlang::check_dots_empty()
275274
.lags <- lags
276275
if (is.list(lags)) lags <- unlist(lags)
277-
method <- match.arg(method)
278-
outcome_transform <- match.arg(outcome_transform)
276+
method <- rlang::arg_match(method)
277+
outcome_transform <- rlang::arg_match(outcome_transform)
279278

280279
arg_is_scalar(ahead, n_training, horizon, log_scale)
281280
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
@@ -287,12 +286,11 @@ arx_class_args_list <- function(
287286
if (is.finite(n_training)) arg_is_pos_int(n_training)
288287
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
289288
if (!is.list(additional_gr_args)) {
290-
cli::cli_abort(
291-
c("`additional_gr_args` must be a {.cls list}.",
292-
"!" = "This is a {.cls {class(additional_gr_args)}}.",
293-
i = "See `?epiprocess::growth_rate` for available arguments."
294-
)
295-
)
289+
cli_abort(c(
290+
"`additional_gr_args` must be a {.cls list}.",
291+
"!" = "This is a {.cls {class(additional_gr_args)}}.",
292+
i = "See `?epiprocess::growth_rate` for available arguments."
293+
))
296294
}
297295
arg_is_pos(check_enough_data_n, allow_null = TRUE)
298296
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

0 commit comments

Comments
 (0)