@@ -55,11 +55,18 @@ arx_classifier <- function(
55
55
wf <- arx_class_epi_workflow(epi_data , outcome , predictors , trainer , args_list )
56
56
wf <- fit(wf , epi_data )
57
57
58
+ if (args_list $ adjust_latency == " none" ) {
59
+ forecast_date_default <- max(epi_data $ time_value )
60
+ if (! is.null(args_list $ forecast_date ) && args_list $ forecast_date != forecast_date_default ) {
61
+ cli_warn(" The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date}." )
62
+ }
63
+ } else {
64
+ forecast_date_default <- attributes(epi_data )$ metadata $ as_of
65
+ }
66
+ forecast_date <- args_list $ forecast_date %|| % forecast_date_default
67
+ target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
58
68
preds <- forecast(
59
69
wf ,
60
- fill_locf = TRUE ,
61
- n_recent = args_list $ nafill_buffer ,
62
- forecast_date = args_list $ forecast_date %|| % max(epi_data $ time_value )
63
70
) %> %
64
71
as_tibble() %> %
65
72
select(- time_value )
@@ -125,27 +132,39 @@ arx_class_epi_workflow <- function(
125
132
if (! (is.null(trainer ) || is_classification(trainer ))) {
126
133
cli_abort(" `trainer` must be a {.pkg parsnip} model of mode 'classification'." )
127
134
}
135
+
136
+ if (args_list $ adjust_latency == " none" ) {
137
+ forecast_date_default <- max(epi_data $ time_value )
138
+ if (! is.null(args_list $ forecast_date ) && args_list $ forecast_date != forecast_date_default ) {
139
+ cli_warn(" The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date}." )
140
+ }
141
+ } else {
142
+ forecast_date_default <- attributes(epi_data )$ metadata $ as_of
143
+ }
144
+ forecast_date <- args_list $ forecast_date %|| % forecast_date_default
145
+ target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
146
+
128
147
lags <- arx_lags_validator(predictors , args_list $ lags )
129
148
130
149
# --- preprocessor
131
150
# ------- predictors
132
151
r <- epi_recipe(epi_data ) %> %
133
152
step_growth_rate(
134
- dplyr :: all_of(predictors ),
153
+ all_of(predictors ),
135
154
role = " grp" ,
136
155
horizon = args_list $ horizon ,
137
156
method = args_list $ method ,
138
157
log_scale = args_list $ log_scale ,
139
158
additional_gr_args_list = args_list $ additional_gr_args
140
159
)
141
160
for (l in seq_along(lags )) {
142
- p <- predictors [l ]
143
- p <- as.character(glue :: glue_data(args_list , " gr_{horizon}_{method}_{p }" ))
144
- r <- step_epi_lag(r , !! p , lag = lags [[l ]])
161
+ pred_names <- predictors [l ]
162
+ pred_names <- as.character(glue :: glue_data(args_list , " gr_{horizon}_{method}_{pred_names }" ))
163
+ r <- step_epi_lag(r , !! pred_names , lag = lags [[l ]])
145
164
}
146
165
# ------- outcome
147
166
if (args_list $ outcome_transform == " lag_difference" ) {
148
- o <- as.character(
167
+ pre_out_name <- as.character(
149
168
glue :: glue_data(args_list , " lag_diff_{horizon}_{outcome}" )
150
169
)
151
170
r <- r %> %
@@ -156,7 +175,7 @@ arx_class_epi_workflow <- function(
156
175
)
157
176
}
158
177
if (args_list $ outcome_transform == " growth_rate" ) {
159
- o <- as.character(
178
+ pre_out_name <- as.character(
160
179
glue :: glue_data(args_list , " gr_{horizon}_{method}_{outcome}" )
161
180
)
162
181
if (! (outcome %in% predictors )) {
@@ -171,11 +190,30 @@ arx_class_epi_workflow <- function(
171
190
)
172
191
}
173
192
}
174
- o2 <- rlang :: sym(paste0(" ahead_" , args_list $ ahead , " _" , o ))
193
+ # regex that will match any amount of adjustment for the ahead
194
+ ahead_out_name_regex <- glue :: glue(" ahead_[0-9]*_{pre_out_name}" )
195
+ method_adjust_latency <- args_list $ adjust_latency
196
+ if (method_adjust_latency != " none" ) {
197
+ if (method_adjust_latency != " extend_ahead" ) {
198
+ cli_abort(" only extend_ahead is currently supported" ,
199
+ class = " epipredict__arx_classifier__adjust_latency_unsupported_method"
200
+ )
201
+ }
202
+ r <- r %> % step_adjust_latency(!! pre_out_name ,
203
+ fixed_forecast_date = forecast_date ,
204
+ method = method_adjust_latency
205
+ )
206
+ }
175
207
r <- r %> %
176
- step_epi_ahead(!! o , ahead = args_list $ ahead , role = " pre-outcome" ) %> %
177
- recipes :: step_mutate(
178
- outcome_class = cut(!! o2 , breaks = args_list $ breaks ),
208
+ step_epi_ahead(!! pre_out_name , ahead = args_list $ ahead , role = " pre-outcome" )
209
+ r <- r %> %
210
+ step_mutate(
211
+ across(
212
+ matches(ahead_out_name_regex ),
213
+ ~ cut(.x , breaks = args_list $ breaks ),
214
+ .names = " outcome_class" ,
215
+ .unpack = TRUE
216
+ ),
179
217
role = " outcome"
180
218
) %> %
181
219
step_epi_naomit() %> %
@@ -192,10 +230,6 @@ arx_class_epi_workflow <- function(
192
230
)
193
231
}
194
232
195
-
196
- forecast_date <- args_list $ forecast_date %|| % max(epi_data $ time_value )
197
- target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
198
-
199
233
# --- postprocessor
200
234
f <- frosting() %> % layer_predict() # %>% layer_naomit()
201
235
f <- layer_add_forecast_date(f , forecast_date = forecast_date ) %> %
@@ -260,13 +294,14 @@ arx_class_args_list <- function(
260
294
n_training = Inf ,
261
295
forecast_date = NULL ,
262
296
target_date = NULL ,
297
+ adjust_latency = c(" none" , " extend_ahead" , " extend_lags" , " locf" ),
298
+ warn_latency = TRUE ,
263
299
outcome_transform = c(" growth_rate" , " lag_difference" ),
264
300
breaks = 0.25 ,
265
301
horizon = 7L ,
266
302
method = c(" rel_change" , " linear_reg" ),
267
303
log_scale = FALSE ,
268
304
additional_gr_args = list (),
269
- nafill_buffer = Inf ,
270
305
check_enough_data_n = NULL ,
271
306
check_enough_data_epi_keys = NULL ,
272
307
... ) {
@@ -276,15 +311,15 @@ arx_class_args_list <- function(
276
311
method <- rlang :: arg_match(method )
277
312
outcome_transform <- rlang :: arg_match(outcome_transform )
278
313
279
- arg_is_scalar(ahead , n_training , horizon , log_scale )
314
+ adjust_latency <- rlang :: arg_match(adjust_latency )
315
+ arg_is_scalar(ahead , n_training , horizon , log_scale , adjust_latency , warn_latency )
280
316
arg_is_scalar(forecast_date , target_date , allow_null = TRUE )
281
317
arg_is_date(forecast_date , target_date , allow_null = TRUE )
282
318
arg_is_nonneg_int(ahead , lags , horizon )
283
319
arg_is_numeric(breaks )
284
320
arg_is_lgl(log_scale )
285
321
arg_is_pos(n_training )
286
322
if (is.finite(n_training )) arg_is_pos_int(n_training )
287
- if (is.finite(nafill_buffer )) arg_is_pos_int(nafill_buffer , allow_null = TRUE )
288
323
if (! is.list(additional_gr_args )) {
289
324
cli_abort(c(
290
325
" `additional_gr_args` must be a {.cls list}." ,
@@ -297,10 +332,13 @@ arx_class_args_list <- function(
297
332
298
333
if (! is.null(forecast_date ) && ! is.null(target_date )) {
299
334
if (forecast_date + ahead != target_date ) {
300
- cli :: cli_warn(c(
301
- " `forecast_date` + `ahead` must equal `target_date`." ,
302
- i = " {.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}."
303
- ))
335
+ cli_warn(
336
+ paste0(
337
+ " `forecast_date` {.val {forecast_date}} +" ,
338
+ " `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}."
339
+ ),
340
+ class = " epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
341
+ )
304
342
}
305
343
}
306
344
@@ -318,13 +356,13 @@ arx_class_args_list <- function(
318
356
breaks ,
319
357
forecast_date ,
320
358
target_date ,
359
+ adjust_latency ,
321
360
outcome_transform ,
322
361
max_lags ,
323
362
horizon ,
324
363
method ,
325
364
log_scale ,
326
365
additional_gr_args ,
327
- nafill_buffer ,
328
366
check_enough_data_n ,
329
367
check_enough_data_epi_keys
330
368
),
@@ -337,5 +375,3 @@ print.arx_class <- function(x, ...) {
337
375
name <- " ARX Classifier"
338
376
NextMethod(name = name , ... )
339
377
}
340
-
341
- # this is a trivial change to induce a check
0 commit comments