Skip to content

Commit f36f6fa

Browse files
authored
Merge pull request #334 from cmu-delphi/adjustAheadLayerAdditions
initial layer adjustments
2 parents 1ada3d0 + cf8fed6 commit f36f6fa

33 files changed

+757
-351
lines changed

NAMESPACE

+7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ S3method(autoplot,canned_epipred)
1919
S3method(autoplot,epi_workflow)
2020
S3method(bake,check_enough_train_data)
2121
S3method(bake,epi_recipe)
22+
S3method(bake,step_adjust_latency)
2223
S3method(bake,step_epi_ahead)
2324
S3method(bake,step_epi_lag)
2425
S3method(bake,step_growth_rate)
@@ -225,9 +226,12 @@ importFrom(dplyr,"%>%")
225226
importFrom(dplyr,across)
226227
importFrom(dplyr,all_of)
227228
importFrom(dplyr,group_by)
229+
importFrom(dplyr,join_by)
230+
importFrom(dplyr,left_join)
228231
importFrom(dplyr,n)
229232
importFrom(dplyr,pull)
230233
importFrom(dplyr,rowwise)
234+
importFrom(dplyr,select)
231235
importFrom(dplyr,summarise)
232236
importFrom(dplyr,tibble)
233237
importFrom(dplyr,ungroup)
@@ -236,6 +240,7 @@ importFrom(generics,augment)
236240
importFrom(generics,fit)
237241
importFrom(generics,forecast)
238242
importFrom(ggplot2,autoplot)
243+
importFrom(glue,glue)
239244
importFrom(hardhat,refresh_blueprint)
240245
importFrom(hardhat,run_mold)
241246
importFrom(magrittr,"%>%")
@@ -249,6 +254,7 @@ importFrom(rlang,"%||%")
249254
importFrom(rlang,":=")
250255
importFrom(rlang,as_function)
251256
importFrom(rlang,caller_env)
257+
importFrom(rlang,enquos)
252258
importFrom(rlang,global_env)
253259
importFrom(rlang,is_null)
254260
importFrom(rlang,set_names)
@@ -275,3 +281,4 @@ importFrom(vctrs,vec_data)
275281
importFrom(vctrs,vec_ptype_abbr)
276282
importFrom(vctrs,vec_ptype_full)
277283
importFrom(vctrs,vec_recycle_common)
284+
importFrom(workflows,extract_preprocessor)

R/arx_classifier.R

+9-2
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,18 @@ arx_classifier <- function(
5454
wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
5555
wf <- generics::fit(wf, epi_data)
5656

57+
latency_adjust_fd <- if (is.null(args_list$adjust_latency)) {
58+
max(epi_data$time_value)
59+
} else {
60+
attributes(epi_data)$metadata$as_of
61+
}
62+
forecast_date <- args_list$forecast_date %||% latency_adjust_fd
63+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
5764
preds <- forecast(
5865
wf,
59-
fill_locf = TRUE,
66+
fill_locf = is.null(args_list$adjust_latency),
6067
n_recent = args_list$nafill_buffer,
61-
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
68+
forecast_date = forecast_date
6269
) %>%
6370
tibble::as_tibble() %>%
6471
dplyr::select(-time_value)

R/arx_forecaster.R

+53-28
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# TODO add latency to default forecaster
21
#' Direct autoregressive forecaster with covariates
32
#'
43
#' This is an autoregressive forecasting model for
@@ -54,7 +53,7 @@ arx_forecaster <- function(
5453

5554
preds <- forecast(
5655
wf,
57-
fill_locf = TRUE,
56+
fill_locf = is.null(args_list$adjust_latency),
5857
n_recent = args_list$nafill_buffer,
5958
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
6059
) %>%
@@ -119,6 +118,17 @@ arx_fcast_epi_workflow <- function(
119118
if (!(is.null(trainer) || is_regression(trainer))) {
120119
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
121120
}
121+
# forecast_date is first what they set;
122+
# if they don't and they're not adjusting latency, it defaults to the max time_value
123+
  # if they are adjusting latency, it defaults to the `as_of` of `epi_data`
124+
latency_adjust_fd <- if (is.null(args_list$adjust_latency)) {
125+
max(epi_data$time_value)
126+
} else {
127+
attributes(epi_data)$metadata$as_of
128+
}
129+
forecast_date <- args_list$forecast_date %||% latency_adjust_fd
130+
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
131+
122132
lags <- arx_lags_validator(predictors, args_list$lags)
123133

124134
# --- preprocessor
@@ -128,26 +138,34 @@ arx_fcast_epi_workflow <- function(
128138
r <- step_epi_lag(r, !!p, lag = lags[[l]])
129139
}
130140
r <- r %>%
131-
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
132-
step_epi_naomit() %>%
133-
step_training_window(n_recent = args_list$n_training) %>%
134-
{
135-
if (!is.null(args_list$check_enough_data_n)) {
136-
check_enough_train_data(
137-
.,
138-
all_predictors(),
139-
!!outcome,
140-
n = args_list$check_enough_data_n,
141-
epi_keys = args_list$check_enough_data_epi_keys,
142-
drop_na = FALSE
143-
)
144-
} else {
145-
.
146-
}
141+
step_epi_ahead(!!outcome, ahead = args_list$ahead)
142+
method <- args_list$adjust_latency
143+
if (!is.null(method)) {
144+
if (method == "extend_ahead") {
145+
r <- r %>% step_adjust_latency(all_outcomes(),
146+
fixed_forecast_date = forecast_date,
147+
method = method
148+
)
149+
} else if (method == "extend_lags") {
150+
r <- r %>% step_adjust_latency(all_predictors(),
151+
fixed_forecast_date = forecast_date,
152+
method = method
153+
)
147154
}
155+
}
156+
r <- r %>%
157+
step_epi_naomit() %>%
158+
step_training_window(n_recent = args_list$n_training)
159+
if (!is.null(args_list$check_enough_data_n)) {
160+
r <- r %>% check_enough_train_data(
161+
all_predictors(),
162+
!!outcome,
163+
n = args_list$check_enough_data_n,
164+
epi_keys = args_list$check_enough_data_epi_keys,
165+
drop_na = FALSE
166+
)
167+
}
148168

149-
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
150-
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
151169

152170
# --- postprocessor
153171
f <- frosting() %>% layer_predict() # %>% layer_naomit()
@@ -159,11 +177,11 @@ arx_fcast_epi_workflow <- function(
159177
))
160178
args_list$quantile_levels <- quantile_levels
161179
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
162-
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
180+
f <- f %>%
181+
layer_quantile_distn(quantile_levels = quantile_levels) %>%
163182
layer_point_from_distn()
164183
} else {
165-
f <- layer_residual_quantiles(
166-
f,
184+
f <- f %>% layer_residual_quantiles(
167185
quantile_levels = args_list$quantile_levels,
168186
symmetrize = args_list$symmetrize,
169187
by_key = args_list$quantile_by_key
@@ -189,10 +207,15 @@ arx_fcast_epi_workflow <- function(
189207
#' @param n_training Integer. An upper limit for the number of rows per
190208
#' key that are used for training
191209
#' (in the time unit of the `epi_df`).
192-
#' @param forecast_date Date. The date on which the forecast is created.
193-
#' The default `NULL` will attempt to determine this automatically.
194-
#' @param target_date Date. The date for which the forecast is intended.
195-
#' The default `NULL` will attempt to determine this automatically.
210+
#' @param forecast_date Date. The date on which the forecast is created. The
211+
#' default `NULL` will attempt to determine this automatically either as the
212+
#' max time value if there is no latency adjustment, or as the `as_of` of
213+
#' `epi_data` if `adjust_latency` is non-`NULL`.
214+
#' @param target_date Date. The date for which the forecast is intended. The
215+
#' default `NULL` will attempt to determine this automatically as
216+
#' `forecast_date + ahead`.
217+
#' @param adjust_latency Character or `NULL`. One of the `method`s of
218+
#' `step_adjust_latency`, or `NULL` (in which case there is no adjustment).
196219
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
197220
#' prediction intervals. These are created by computing the quantiles of
198221
#' training residuals. A `NULL` value will result in point forecasts only.
@@ -238,6 +261,7 @@ arx_args_list <- function(
238261
n_training = Inf,
239262
forecast_date = NULL,
240263
target_date = NULL,
264+
adjust_latency = NULL,
241265
quantile_levels = c(0.05, 0.95),
242266
symmetrize = TRUE,
243267
nonneg = TRUE,
@@ -253,7 +277,7 @@ arx_args_list <- function(
253277

254278
arg_is_scalar(ahead, n_training, symmetrize, nonneg)
255279
arg_is_chr(quantile_by_key, allow_empty = TRUE)
256-
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
280+
arg_is_scalar(forecast_date, target_date, adjust_latency, allow_null = TRUE)
257281
arg_is_date(forecast_date, target_date, allow_null = TRUE)
258282
arg_is_nonneg_int(ahead, lags)
259283
arg_is_lgl(symmetrize, nonneg)
@@ -282,6 +306,7 @@ arx_args_list <- function(
282306
quantile_levels,
283307
forecast_date,
284308
target_date,
309+
adjust_latency,
285310
symmetrize,
286311
nonneg,
287312
max_lags,

R/epi_recipe.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ prep.epi_recipe <- function(
431431
x, training = NULL, fresh = FALSE, verbose = FALSE,
432432
retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
433433
if (is.null(training)) {
434-
cli::cli_warn(c(
434+
cli::cli_warn(paste(
435435
"!" = "No training data was supplied to {.fn prep}.",
436436
"!" = "Unlike a {.cls recipe}, an {.cls epi_recipe} does not ",
437437
"!" = "store the full template data in the object.",
@@ -577,7 +577,6 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
577577
new_data
578578
}
579579

580-
581580
kill_levels <- function(x, keys) {
582581
for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA)
583582
x

R/get_test_data.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ get_test_data <- function(
116116
cannot_be_used <- x %>%
117117
dplyr::filter(forecast_date - time_value <= n_recent) %>%
118118
dplyr::mutate(fillers = forecast_date - time_value > min_required) %>%
119-
dplyr::summarize(
119+
dplyr::summarise(
120120
dplyr::across(
121121
-tidyselect::any_of(epi_keys(recipe)),
122122
~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1))

R/layer_add_forecast_date.R

+11-13
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
# TODO adapt this to latency
21
#' Postprocessing step to add the forecast date
32
#'
43
#' @param frosting a `frosting` postprocessor
54
#' @param forecast_date The forecast date to add as a column to the `epi_df`.
6-
#' For most cases, this should be specified in the form "yyyy-mm-dd". Note that
7-
#' when the forecast date is left unspecified, it is set to the maximum time
8-
#' value from the data used in pre-processing, fitting the model, and
9-
#' postprocessing.
5+
#' For most cases, this should be specified in the form "yyyy-mm-dd". Note
6+
#' that when the forecast date is left unspecified, it is set to one of two
7+
#' values. If there is a `step_adjust_latency` step present, it uses the
8+
#' `forecast_date` as set in that function. Otherwise, it uses the maximum
9+
#' `time_value` across the data used for pre-processing, fitting the model,
10+
#' and postprocessing.
1011
#' @param id a random id string
1112
#'
1213
#' @return an updated `frosting` postprocessor
@@ -86,17 +87,14 @@ layer_add_forecast_date_new <- function(forecast_date, id) {
8687
}
8788

8889
#' @export
90+
#' @importFrom workflows extract_preprocessor
8991
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
90-
if (is.null(object$forecast_date)) {
91-
max_time_value <- max(
92-
workflows::extract_preprocessor(workflow)$max_time_value,
92+
forecast_date <- object$forecast_date %||%
93+
get_forecast_date_in_layer(
94+
extract_preprocessor(workflow),
9395
workflow$fit$meta$max_time_value,
94-
max(new_data$time_value)
96+
new_data
9597
)
96-
forecast_date <- max_time_value
97-
} else {
98-
forecast_date <- object$forecast_date
99-
}
10098

10199
expected_time_type <- attr(
102100
workflows::extract_preprocessor(workflow)$template, "metadata"

R/layer_add_target_date.R

+41-21
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
1-
# TODO adapt this to latency
21
#' Postprocessing step to add the target date
32
#'
43
#' @param frosting a `frosting` postprocessor
5-
#' @param target_date The target date to add as a column to the
6-
#' `epi_df`. If there's a forecast date specified in a layer, then
7-
#' it is the forecast date plus `ahead` (from `step_epi_ahead` in
8-
#' the `epi_recipe`). Otherwise, it is the maximum `time_value`
9-
#' (from the data used in pre-processing, fitting the model, and
10-
#' postprocessing) plus `ahead`, where `ahead` has been specified in
11-
#' preprocessing. The user may override these by specifying a
12-
#' target date of their own (of the form "yyyy-mm-dd").
4+
#' @param target_date The target date to add as a column to the `epi_df`. If
5+
#' there's a forecast date specified upstream (either in a
6+
#'   `step_adjust_latency` or in a `layer_add_forecast_date`), then it is the
7+
#' forecast date plus `ahead` (from `step_epi_ahead` in the `epi_recipe`).
8+
#' Otherwise, it is the maximum `time_value` (from the data used in
9+
#' pre-processing, fitting the model, and postprocessing) plus `ahead`, where
10+
#' `ahead` has been specified in preprocessing. The user may override these by
11+
#' specifying a target date of their own (of the form "yyyy-mm-dd").
1312
#' @param id a random id string
1413
#'
1514
#' @return an updated `frosting` postprocessor
1615
#'
1716
#' @details By default, this function assumes that a value for `ahead`
1817
#' has been specified in a preprocessing step (most likely in
19-
#' `step_epi_ahead`). Then, `ahead` is added to the maximum `time_value`
20-
#' in the test data to get the target date.
18+
#' `step_epi_ahead`). Then, `ahead` is added to the `forecast_date`
19+
#' in the test data to get the target date. `forecast_date` can be set in 3 ways:
20+
#' 1. `step_adjust_latency`, which typically uses the training `epi_df`'s `as_of`
21+
#' 2. `layer_add_forecast_date`, which inherits from 1 if not manually specified
22+
#' 3. if none of those are the case, it is simply the maximum `time_value` over
23+
#' every dataset used (prep, training, and prediction).
2124
#'
2225
#' @export
2326
#' @examples
@@ -41,8 +44,14 @@
4144
#' p <- forecast(wf1)
4245
#' p
4346
#'
44-
#' # Use ahead + max time value from pre, fit, post
45-
#' # which is the same if include `layer_add_forecast_date()`
47+
#' # Use ahead + forecast_date from adjust_latency
48+
#' # setting the `as_of` to something realistic
49+
#' attributes(jhu)$metadata$as_of <- max(jhu$time_value) + 3
50+
#' r <- epi_recipe(jhu) %>%
51+
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
52+
#' step_epi_ahead(death_rate, ahead = 7) %>%
53+
#' step_adjust_latency(method = "extend_ahead") %>%
54+
#' step_epi_naomit()
4655
#' f2 <- frosting() %>%
4756
#' layer_predict() %>%
4857
#' layer_add_target_date() %>%
@@ -52,15 +61,26 @@
5261
#' p2 <- forecast(wf2)
5362
#' p2
5463
#'
55-
#' # Specify own target date
64+
#' # Use ahead + max time value from pre, fit, post
65+
#' # which is the same if you include `layer_add_forecast_date()`
5666
#' f3 <- frosting() %>%
5767
#' layer_predict() %>%
58-
#' layer_add_target_date(target_date = "2022-01-08") %>%
68+
#' layer_add_target_date() %>%
5969
#' layer_naomit(.pred)
6070
#' wf3 <- wf %>% add_frosting(f3)
6171
#'
62-
#' p3 <- forecast(wf3)
63-
#' p3
72+
#' p3 <- forecast(wf3)
73+
#' p3
74+
#'
75+
#' # Specify own target date
76+
#' f4 <- frosting() %>%
77+
#' layer_predict() %>%
78+
#' layer_add_target_date(target_date = "2022-01-08") %>%
79+
#' layer_naomit(.pred)
80+
#' wf4 <- wf %>% add_frosting(f4)
81+
#'
82+
#' p4 <- forecast(wf4)
83+
#' p4
6484
layer_add_target_date <-
6585
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {
6686
arg_is_chr_scalar(id)
@@ -108,13 +128,13 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data
108128
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
109129
target_date <- forecast_date + ahead
110130
} else {
111-
max_time_value <- max(
112-
workflows::extract_preprocessor(workflow)$max_time_value,
131+
forecast_date <- get_forecast_date_in_layer(
132+
extract_preprocessor(workflow),
113133
workflow$fit$meta$max_time_value,
114-
max(new_data$time_value)
134+
new_data
115135
)
116136
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
117-
target_date <- max_time_value + ahead
137+
target_date <- forecast_date + ahead
118138
}
119139

120140
object$target_date <- target_date

R/layer_residual_quantiles.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ slather.layer_residual_quantiles <-
115115
}
116116

117117
r <- r %>%
118-
dplyr::summarize(
118+
dplyr::summarise(
119119
dstn = list(quantile(
120120
c(.resid, s * .resid),
121121
probs = object$quantile_levels, na.rm = TRUE

0 commit comments

Comments
 (0)