Skip to content

Commit 3892a53

Browse files
authored
Merge pull request #296 from cmu-delphi/adjustAhead
Adjust ahead
2 parents 3f174fc + 053b501 commit 3892a53

File tree

90 files changed

+3321
-719
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+3321
-719
lines changed

.Rbuildignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
^doc$
2121
^Meta$
2222
^.lintr$
23-
^.venv$
23+
^.venv$
24+
^inst/templates$

DESCRIPTION

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.1.0
3+
Version: 0.1.1
44
Authors@R: c(
55
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
@@ -40,6 +40,7 @@ Imports:
4040
magrittr,
4141
recipes (>= 1.0.4),
4242
rlang (>= 1.1.0),
43+
purrr,
4344
stats,
4445
tibble,
4546
tidyr,

DEVELOPMENT.md

+2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ The `main` version is available at `file:///<local path>/epidatr/epipredict/inde
3232
You can also build the docs manually and launch the site with python. From the terminal, this looks like
3333

3434
```bash
35+
R -e 'pkgdown::clean_site()'
3536
R -e 'devtools::document()'
37+
R -e 'pkgdown::build_site()'
3638
python -m http.server -d docs
3739
```
3840

NAMESPACE

+25-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ S3method(autoplot,canned_epipred)
1919
S3method(autoplot,epi_workflow)
2020
S3method(bake,check_enough_train_data)
2121
S3method(bake,epi_recipe)
22+
S3method(bake,step_adjust_latency)
2223
S3method(bake,step_epi_ahead)
2324
S3method(bake,step_epi_lag)
2425
S3method(bake,step_epi_slide)
@@ -58,6 +59,7 @@ S3method(predict,epi_workflow)
5859
S3method(predict,flatline)
5960
S3method(prep,check_enough_train_data)
6061
S3method(prep,epi_recipe)
62+
S3method(prep,step_adjust_latency)
6163
S3method(prep,step_epi_ahead)
6264
S3method(prep,step_epi_lag)
6365
S3method(prep,step_epi_slide)
@@ -87,6 +89,7 @@ S3method(print,layer_quantile_distn)
8789
S3method(print,layer_residual_quantiles)
8890
S3method(print,layer_threshold)
8991
S3method(print,layer_unnest)
92+
S3method(print,step_adjust_latency)
9093
S3method(print,step_epi_ahead)
9194
S3method(print,step_epi_lag)
9295
S3method(print,step_epi_slide)
@@ -195,6 +198,7 @@ export(remove_frosting)
195198
export(remove_model)
196199
export(slather)
197200
export(smooth_quantile_reg)
201+
export(step_adjust_latency)
198202
export(step_epi_ahead)
199203
export(step_epi_lag)
200204
export(step_epi_naomit)
@@ -225,6 +229,7 @@ importFrom(checkmate,test_numeric)
225229
importFrom(checkmate,test_scalar)
226230
importFrom(cli,cli_abort)
227231
importFrom(cli,cli_warn)
232+
importFrom(dplyr,"%>%")
228233
importFrom(dplyr,across)
229234
importFrom(dplyr,all_of)
230235
importFrom(dplyr,any_of)
@@ -235,13 +240,20 @@ importFrom(dplyr,everything)
235240
importFrom(dplyr,filter)
236241
importFrom(dplyr,full_join)
237242
importFrom(dplyr,group_by)
243+
importFrom(dplyr,group_by_at)
244+
importFrom(dplyr,join_by)
238245
importFrom(dplyr,left_join)
239246
importFrom(dplyr,mutate)
247+
importFrom(dplyr,n)
248+
importFrom(dplyr,pull)
240249
importFrom(dplyr,relocate)
241250
importFrom(dplyr,rename)
251+
importFrom(dplyr,rowwise)
242252
importFrom(dplyr,select)
243253
importFrom(dplyr,summarise)
244254
importFrom(dplyr,summarize)
255+
importFrom(dplyr,tibble)
256+
importFrom(dplyr,tribble)
245257
importFrom(dplyr,ungroup)
246258
importFrom(epiprocess,epi_slide)
247259
importFrom(epiprocess,growth_rate)
@@ -255,18 +267,20 @@ importFrom(ggplot2,geom_line)
255267
importFrom(ggplot2,geom_linerange)
256268
importFrom(ggplot2,geom_point)
257269
importFrom(ggplot2,geom_ribbon)
270+
importFrom(glue,glue)
271+
importFrom(hardhat,extract_recipe)
258272
importFrom(hardhat,refresh_blueprint)
259273
importFrom(hardhat,run_mold)
260274
importFrom(magrittr,"%>%")
261275
importFrom(recipes,bake)
276+
importFrom(recipes,detect_step)
262277
importFrom(recipes,prep)
263278
importFrom(recipes,rand_id)
264279
importFrom(rlang,"!!!")
265280
importFrom(rlang,"!!")
266281
importFrom(rlang,"%@%")
267282
importFrom(rlang,"%||%")
268283
importFrom(rlang,":=")
269-
importFrom(rlang,abort)
270284
importFrom(rlang,arg_match)
271285
importFrom(rlang,as_function)
272286
importFrom(rlang,caller_arg)
@@ -276,16 +290,19 @@ importFrom(rlang,enquos)
276290
importFrom(rlang,expr)
277291
importFrom(rlang,global_env)
278292
importFrom(rlang,inject)
293+
importFrom(rlang,is_empty)
279294
importFrom(rlang,is_logical)
280295
importFrom(rlang,is_null)
281296
importFrom(rlang,is_true)
297+
importFrom(rlang,list2)
282298
importFrom(rlang,set_names)
283299
importFrom(rlang,sym)
284300
importFrom(stats,as.formula)
285301
importFrom(stats,family)
286302
importFrom(stats,lm)
287303
importFrom(stats,median)
288304
importFrom(stats,model.frame)
305+
importFrom(stats,na.omit)
289306
importFrom(stats,poly)
290307
importFrom(stats,predict)
291308
importFrom(stats,qnorm)
@@ -294,6 +311,12 @@ importFrom(stats,residuals)
294311
importFrom(tibble,as_tibble)
295312
importFrom(tibble,tibble)
296313
importFrom(tidyr,crossing)
314+
importFrom(tidyr,drop_na)
315+
importFrom(tidyr,expand_grid)
316+
importFrom(tidyr,fill)
317+
importFrom(tidyr,unnest)
318+
importFrom(tidyselect,all_of)
319+
importFrom(utils,capture.output)
297320
importFrom(vctrs,as_list_of)
298321
importFrom(vctrs,field)
299322
importFrom(vctrs,new_rcrd)
@@ -303,3 +326,4 @@ importFrom(vctrs,vec_data)
303326
importFrom(vctrs,vec_ptype_abbr)
304327
importFrom(vctrs,vec_ptype_full)
305328
importFrom(vctrs,vec_recycle_common)
329+
importFrom(workflows,extract_preprocessor)

NEWS.md

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicate PRs.
44

5+
# epipredict 0.2
6+
7+
## features
8+
- Add `step_adjust_latency`, which gives several methods to adjust the forecast if the `forecast_date` is after the last day of data.
9+
10+
## bugfixes
11+
512
# epipredict 0.1
613

714
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`

R/arx_classifier.R

+63-27
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,18 @@ arx_classifier <- function(
5555
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5656
wf <- fit(wf, epi_data)
5757

58+
if (args_list$adjust_latency == "none") {
59+
forecast_date_default <- max(epi_data$time_value)
60+
if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
61+
cli_warn("The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date}.")
62+
}
63+
} else {
64+
forecast_date_default <- attributes(epi_data)$metadata$as_of
65+
}
66+
forecast_date <- args_list$forecast_date %||% forecast_date_default
67+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
5868
preds <- forecast(
5969
wf,
60-
fill_locf = TRUE,
61-
n_recent = args_list$nafill_buffer,
62-
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
6370
) %>%
6471
as_tibble() %>%
6572
select(-time_value)
@@ -125,27 +132,39 @@ arx_class_epi_workflow <- function(
125132
if (!(is.null(trainer) || is_classification(trainer))) {
126133
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
127134
}
135+
136+
if (args_list$adjust_latency == "none") {
137+
forecast_date_default <- max(epi_data$time_value)
138+
if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) {
139+
cli_warn("The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date}.")
140+
}
141+
} else {
142+
forecast_date_default <- attributes(epi_data)$metadata$as_of
143+
}
144+
forecast_date <- args_list$forecast_date %||% forecast_date_default
145+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
146+
128147
lags <- arx_lags_validator(predictors, args_list$lags)
129148

130149
# --- preprocessor
131150
# ------- predictors
132151
r <- epi_recipe(epi_data) %>%
133152
step_growth_rate(
134-
dplyr::all_of(predictors),
153+
all_of(predictors),
135154
role = "grp",
136155
horizon = args_list$horizon,
137156
method = args_list$method,
138157
log_scale = args_list$log_scale,
139158
additional_gr_args_list = args_list$additional_gr_args
140159
)
141160
for (l in seq_along(lags)) {
142-
p <- predictors[l]
143-
p <- as.character(glue::glue_data(args_list, "gr_{horizon}_{method}_{p}"))
144-
r <- step_epi_lag(r, !!p, lag = lags[[l]])
161+
pred_names <- predictors[l]
162+
pred_names <- as.character(glue::glue_data(args_list, "gr_{horizon}_{method}_{pred_names}"))
163+
r <- step_epi_lag(r, !!pred_names, lag = lags[[l]])
145164
}
146165
# ------- outcome
147166
if (args_list$outcome_transform == "lag_difference") {
148-
o <- as.character(
167+
pre_out_name <- as.character(
149168
glue::glue_data(args_list, "lag_diff_{horizon}_{outcome}")
150169
)
151170
r <- r %>%
@@ -156,7 +175,7 @@ arx_class_epi_workflow <- function(
156175
)
157176
}
158177
if (args_list$outcome_transform == "growth_rate") {
159-
o <- as.character(
178+
pre_out_name <- as.character(
160179
glue::glue_data(args_list, "gr_{horizon}_{method}_{outcome}")
161180
)
162181
if (!(outcome %in% predictors)) {
@@ -171,11 +190,30 @@ arx_class_epi_workflow <- function(
171190
)
172191
}
173192
}
174-
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
193+
# regex that will match any amount of adjustment for the ahead
194+
ahead_out_name_regex <- glue::glue("ahead_[0-9]*_{pre_out_name}")
195+
method_adjust_latency <- args_list$adjust_latency
196+
if (method_adjust_latency != "none") {
197+
if (method_adjust_latency != "extend_ahead") {
198+
cli_abort("only extend_ahead is currently supported",
199+
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
200+
)
201+
}
202+
r <- r %>% step_adjust_latency(!!pre_out_name,
203+
fixed_forecast_date = forecast_date,
204+
method = method_adjust_latency
205+
)
206+
}
175207
r <- r %>%
176-
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
177-
recipes::step_mutate(
178-
outcome_class = cut(!!o2, breaks = args_list$breaks),
208+
step_epi_ahead(!!pre_out_name, ahead = args_list$ahead, role = "pre-outcome")
209+
r <- r %>%
210+
step_mutate(
211+
across(
212+
matches(ahead_out_name_regex),
213+
~ cut(.x, breaks = args_list$breaks),
214+
.names = "outcome_class",
215+
.unpack = TRUE
216+
),
179217
role = "outcome"
180218
) %>%
181219
step_epi_naomit() %>%
@@ -192,10 +230,6 @@ arx_class_epi_workflow <- function(
192230
)
193231
}
194232

195-
196-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
197-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
198-
199233
# --- postprocessor
200234
f <- frosting() %>% layer_predict() # %>% layer_naomit()
201235
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
@@ -260,13 +294,14 @@ arx_class_args_list <- function(
260294
n_training = Inf,
261295
forecast_date = NULL,
262296
target_date = NULL,
297+
adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"),
298+
warn_latency = TRUE,
263299
outcome_transform = c("growth_rate", "lag_difference"),
264300
breaks = 0.25,
265301
horizon = 7L,
266302
method = c("rel_change", "linear_reg"),
267303
log_scale = FALSE,
268304
additional_gr_args = list(),
269-
nafill_buffer = Inf,
270305
check_enough_data_n = NULL,
271306
check_enough_data_epi_keys = NULL,
272307
...) {
@@ -276,15 +311,15 @@ arx_class_args_list <- function(
276311
method <- rlang::arg_match(method)
277312
outcome_transform <- rlang::arg_match(outcome_transform)
278313

279-
arg_is_scalar(ahead, n_training, horizon, log_scale)
314+
adjust_latency <- rlang::arg_match(adjust_latency)
315+
arg_is_scalar(ahead, n_training, horizon, log_scale, adjust_latency, warn_latency)
280316
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
281317
arg_is_date(forecast_date, target_date, allow_null = TRUE)
282318
arg_is_nonneg_int(ahead, lags, horizon)
283319
arg_is_numeric(breaks)
284320
arg_is_lgl(log_scale)
285321
arg_is_pos(n_training)
286322
if (is.finite(n_training)) arg_is_pos_int(n_training)
287-
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
288323
if (!is.list(additional_gr_args)) {
289324
cli_abort(c(
290325
"`additional_gr_args` must be a {.cls list}.",
@@ -297,10 +332,13 @@ arx_class_args_list <- function(
297332

298333
if (!is.null(forecast_date) && !is.null(target_date)) {
299334
if (forecast_date + ahead != target_date) {
300-
cli::cli_warn(c(
301-
"`forecast_date` + `ahead` must equal `target_date`.",
302-
i = "{.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}."
303-
))
335+
cli_warn(
336+
paste0(
337+
"`forecast_date` {.val {forecast_date}} +",
338+
" `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}."
339+
),
340+
class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date"
341+
)
304342
}
305343
}
306344

@@ -318,13 +356,13 @@ arx_class_args_list <- function(
318356
breaks,
319357
forecast_date,
320358
target_date,
359+
adjust_latency,
321360
outcome_transform,
322361
max_lags,
323362
horizon,
324363
method,
325364
log_scale,
326365
additional_gr_args,
327-
nafill_buffer,
328366
check_enough_data_n,
329367
check_enough_data_epi_keys
330368
),
@@ -337,5 +375,3 @@ print.arx_class <- function(x, ...) {
337375
name <- "ARX Classifier"
338376
NextMethod(name = name, ...)
339377
}
340-
341-
# this is a trivial change to induce a check

0 commit comments

Comments
 (0)