26
26
# ' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
27
27
# '
28
28
# ' @examples
29
+ # ' library(dplyr)
29
30
# ' jhu <- case_death_rate_subset %>%
30
- # ' dplyr:: filter(time_value >= as.Date("2021-11-01"))
31
+ # ' filter(time_value >= as.Date("2021-11-01"))
31
32
# '
32
33
# ' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate"))
33
34
# '
@@ -45,23 +46,23 @@ arx_classifier <- function(
45
46
epi_data ,
46
47
outcome ,
47
48
predictors ,
48
- trainer = parsnip :: logistic_reg(),
49
+ trainer = logistic_reg(),
49
50
args_list = arx_class_args_list()) {
50
51
if (! is_classification(trainer )) {
51
- cli :: cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
52
+ cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
52
53
}
53
54
54
55
wf <- arx_class_epi_workflow(epi_data , outcome , predictors , trainer , args_list )
55
- wf <- generics :: fit(wf , epi_data )
56
+ wf <- fit(wf , epi_data )
56
57
57
58
preds <- forecast(
58
59
wf ,
59
60
fill_locf = TRUE ,
60
61
n_recent = args_list $ nafill_buffer ,
61
62
forecast_date = args_list $ forecast_date %|| % max(epi_data $ time_value )
62
63
) %> %
63
- tibble :: as_tibble() %> %
64
- dplyr :: select(- time_value )
64
+ as_tibble() %> %
65
+ select(- time_value )
65
66
66
67
structure(
67
68
list (
@@ -95,17 +96,17 @@ arx_classifier <- function(
95
96
# ' @export
96
97
# ' @seealso [arx_classifier()]
97
98
# ' @examples
98
- # '
99
+ # ' library(dplyr)
99
100
# ' jhu <- case_death_rate_subset %>%
100
- # ' dplyr:: filter(time_value >= as.Date("2021-11-01"))
101
+ # ' filter(time_value >= as.Date("2021-11-01"))
101
102
# '
102
103
# ' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
103
104
# '
104
105
# ' arx_class_epi_workflow(
105
106
# ' jhu,
106
107
# ' "death_rate",
107
108
# ' c("case_rate", "death_rate"),
108
- # ' trainer = parsnip:: multinom_reg(),
109
+ # ' trainer = multinom_reg(),
109
110
# ' args_list = arx_class_args_list(
110
111
# ' breaks = c(-.05, .1), ahead = 14,
111
112
# ' horizon = 14, method = "linear_reg"
@@ -119,18 +120,18 @@ arx_class_epi_workflow <- function(
119
120
args_list = arx_class_args_list()) {
120
121
validate_forecaster_inputs(epi_data , outcome , predictors )
121
122
if (! inherits(args_list , c(" arx_class" , " alist" ))) {
122
- rlang :: abort( " args_list was not created using `arx_class_args_list()." )
123
+ cli_abort( " ` args_list` was not created using `arx_class_args_list()` ." )
123
124
}
124
125
if (! (is.null(trainer ) || is_classification(trainer ))) {
125
- rlang :: abort (" `trainer` must be a `{ parsnip}` model of mode 'classification'." )
126
+ cli_abort (" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
126
127
}
127
128
lags <- arx_lags_validator(predictors , args_list $ lags )
128
129
129
130
# --- preprocessor
130
131
# ------- predictors
131
132
r <- epi_recipe(epi_data ) %> %
132
133
step_growth_rate(
133
- tidyselect :: all_of(predictors ),
134
+ dplyr :: all_of(predictors ),
134
135
role = " grp" ,
135
136
horizon = args_list $ horizon ,
136
137
method = args_list $ method ,
@@ -173,26 +174,24 @@ arx_class_epi_workflow <- function(
173
174
o2 <- rlang :: sym(paste0(" ahead_" , args_list $ ahead , " _" , o ))
174
175
r <- r %> %
175
176
step_epi_ahead(!! o , ahead = args_list $ ahead , role = " pre-outcome" ) %> %
176
- step_mutate(
177
+ recipes :: step_mutate(
177
178
outcome_class = cut(!! o2 , breaks = args_list $ breaks ),
178
179
role = " outcome"
179
180
) %> %
180
181
step_epi_naomit() %> %
181
- step_training_window(n_recent = args_list $ n_training ) %> %
182
- {
183
- if (! is.null(args_list $ check_enough_data_n )) {
184
- check_enough_train_data(
185
- . ,
186
- all_predictors(),
187
- !! outcome ,
188
- n = args_list $ check_enough_data_n ,
189
- epi_keys = args_list $ check_enough_data_epi_keys ,
190
- drop_na = FALSE
191
- )
192
- } else {
193
- .
194
- }
195
- }
182
+ step_training_window(n_recent = args_list $ n_training )
183
+
184
+ if (! is.null(args_list $ check_enough_data_n )) {
185
+ r <- check_enough_train_data(
186
+ r ,
187
+ recipes :: all_predictors(),
188
+ recipes :: all_outcomes(),
189
+ n = args_list $ check_enough_data_n ,
190
+ epi_keys = args_list $ check_enough_data_epi_keys ,
191
+ drop_na = FALSE
192
+ )
193
+ }
194
+
196
195
197
196
forecast_date <- args_list $ forecast_date %|| % max(epi_data $ time_value )
198
197
target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
@@ -264,7 +263,7 @@ arx_class_args_list <- function(
264
263
outcome_transform = c(" growth_rate" , " lag_difference" ),
265
264
breaks = 0.25 ,
266
265
horizon = 7L ,
267
- method = c(" rel_change" , " linear_reg" , " smooth_spline " , " trend_filter " ),
266
+ method = c(" rel_change" , " linear_reg" ),
268
267
log_scale = FALSE ,
269
268
additional_gr_args = list (),
270
269
nafill_buffer = Inf ,
@@ -274,8 +273,8 @@ arx_class_args_list <- function(
274
273
rlang :: check_dots_empty()
275
274
.lags <- lags
276
275
if (is.list(lags )) lags <- unlist(lags )
277
- method <- match.arg (method )
278
- outcome_transform <- match.arg (outcome_transform )
276
+ method <- rlang :: arg_match (method )
277
+ outcome_transform <- rlang :: arg_match (outcome_transform )
279
278
280
279
arg_is_scalar(ahead , n_training , horizon , log_scale )
281
280
arg_is_scalar(forecast_date , target_date , allow_null = TRUE )
@@ -287,12 +286,11 @@ arx_class_args_list <- function(
287
286
if (is.finite(n_training )) arg_is_pos_int(n_training )
288
287
if (is.finite(nafill_buffer )) arg_is_pos_int(nafill_buffer , allow_null = TRUE )
289
288
if (! is.list(additional_gr_args )) {
290
- cli :: cli_abort(
291
- c(" `additional_gr_args` must be a {.cls list}." ,
292
- " !" = " This is a {.cls {class(additional_gr_args)}}." ,
293
- i = " See `?epiprocess::growth_rate` for available arguments."
294
- )
295
- )
289
+ cli_abort(c(
290
+ " `additional_gr_args` must be a {.cls list}." ,
291
+ " !" = " This is a {.cls {class(additional_gr_args)}}." ,
292
+ i = " See `?epiprocess::growth_rate` for available arguments."
293
+ ))
296
294
}
297
295
arg_is_pos(check_enough_data_n , allow_null = TRUE )
298
296
arg_is_chr(check_enough_data_epi_keys , allow_null = TRUE )
0 commit comments