Skip to content

Commit dc6d17d

Browse files
committed
separate step version
1 parent 039e714 commit dc6d17d

File tree

4 files changed

+422
-95
lines changed

4 files changed

+422
-95
lines changed

R/epi_shift.R

+51-30
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,64 @@
22
#'
33
#' This is a lower-level function. As such it performs no error checking.
44
#'
5-
#' @param x Data frame. Variables to shift
6-
#' @param shifts List. Each list element is a vector of shifts.
7-
#' Negative values produce leads. The list should have the same
8-
#' length as the number of columns in `x`.
9-
#' @param time_value Vector. Same length as `x` giving time stamps.
10-
#' @param keys Data frame, vector, or `NULL`. Additional grouping vars.
11-
#' @param out_name Chr. The output list will use this as a prefix.
5+
#' @param x Data frame.
6+
#' @param shift_val a single integer. Negative values produce leads.
7+
#' @param newname the name for the newly shifted column
8+
#' @param key_cols vector, or `NULL`. Additional grouping vars.
129
#'
1310
#' @keywords internal
1411
#'
1512
#' @return a list of tibbles
16-
epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") {
17-
if (!is.data.frame(x)) x <- data.frame(x)
18-
if (is.null(keys)) keys <- rep("empty", nrow(x))
19-
p_in <- ncol(x)
20-
out_list <- tibble::tibble(i = 1:p_in, shift = shifts) %>%
21-
tidyr::unchop(shift) %>% # what is chop
22-
dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>%
23-
# One list element for each shifted feature
24-
pmap(function(i, shift, name) {
25-
tibble(keys,
26-
time_value = time_value + shift, # Shift back
27-
!!name := x[[i]]
28-
)
29-
})
30-
if (is.data.frame(keys)) {
31-
common_names <- c(names(keys), "time_value")
32-
} else {
33-
common_names <- c("keys", "time_value")
34-
}
35-
36-
reduce(out_list, dplyr::full_join, by = common_names)
37-
}
38-
3913
epi_shift_single <- function(x, col, shift_val, newname, key_cols) {
4014
x %>%
4115
dplyr::select(tidyselect::all_of(c(key_cols, col))) %>%
4216
dplyr::mutate(time_value = time_value + shift_val) %>%
4317
dplyr::rename(!!newname := {{ col }})
4418
}
19+
20+
#' lags move columns forward to bring the past up to today, while aheads drag
21+
#' the future back to today
22+
get_sign <- function(object) {
23+
if (object$prefix == "lag_") {
24+
return(1)
25+
} else {
26+
return(-1)
27+
}
28+
}
29+
30+
#' backend for both `bake.step_epi_ahead` and `bake.step_epi_lag`, performs the
31+
#' checks missing in `epi_shift_single`
32+
#' @keywords internal
33+
add_shifted_columns <- function(new_data, object, amount) {
34+
sign_shift <- get_sign(object)
35+
grid <- tidyr::expand_grid(col = object$columns, amount = amount) %>%
36+
dplyr::mutate(
37+
newname = glue::glue("{object$prefix}{amount}_{col}"),
38+
shift_val = sign_shift * amount,
39+
amount = NULL
40+
)
41+
42+
## ensure no name clashes
43+
new_data_names <- colnames(new_data)
44+
intersection <- new_data_names %in% grid$newname
45+
if (any(intersection)) {
46+
rlang::abort(
47+
paste0(
48+
"Name collision occured in `", class(object)[1],
49+
"`. The following variable names already exists: ",
50+
paste0(new_data_names[intersection], collapse = ", "),
51+
"."
52+
)
53+
)
54+
}
55+
ok <- object$keys
56+
shifted <- reduce(
57+
pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
58+
dplyr::full_join,
59+
by = ok
60+
)
61+
dplyr::full_join(new_data, shifted, by = ok) %>%
62+
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
63+
dplyr::arrange(time_value) %>%
64+
dplyr::ungroup()
65+
}

R/step_adjust_latency.R

+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#' adapt the pipeline to latency in the data
2+
#'
3+
#' In the standard case, the pipeline assumes that the last observation is also
4+
#' the day from which the forecast is being made. `step_adjust_latency` uses the
5+
#' `as_of` date of the `epi_df` as the `forecast_date`. This is most useful in
6+
#' realtime and pseudo-prospective forecasting for data where there is some
7+
#' delay between the day recorded and when that data is available.
8+
#'
9+
#' @param recipe A recipe object. The step will be added to the
10+
#' sequence of operations for this recipe.
11+
#' @param ... One or more selector functions to choose variables for this step.
12+
#' See [recipes::selections()] for more details. Typically you will not need
13+
#' to set this manually, as the necessary adjustments will be done for the
14+
#' predictors and outcome.
15+
#' @param method a character. Determines the method by which the
16+
#' forecast handles latency. All of these assume the forecast date is the
17+
#' `as_of` of the `epi_df`. The options are:
18+
#' - `"extend_ahead"`: Lengthen the ahead so that forecasting from the last
19+
#' observation results in a forecast `ahead` after the `as_of` date. E.g. if
20+
#' there are 3 days of latency between the last observation and the `as_of`
21+
#' date for a 4 day ahead forecast, the ahead used in practice is actually 7.
22+
#' - `"locf"`: carries forward the last observed value(s) up to the forecast
23+
#' date. See the Vignette TODO for equivalents using other steps and more
24+
#' sophisticated methods of extrapolation.
25+
#' - `"extend_lags"`: per `epi_key` and `predictor`, adjusts the lag so that
26+
#' the shortest lag at predict time is at the last observation. E.g. if the
27+
#' lags are `c(0,7,14)` for data that is 3 days latent, the actual lags used
28+
#' become `c(3,10,17)`
29+
#' @param default Determines what fills empty rows
30+
#' left by leading/lagging (defaults to NA).
31+
#' @param prefix a character. The prefix matching the one used in either
32+
#' `step_epi_ahead` if `method="extend_ahead"` or `step_epi_lag`
33+
#' if `method="extend_lags"` or "locf".
34+
#' @param skip A logical. Should the step be skipped when the
35+
#' recipe is baked by [bake()]? While all operations are baked
36+
#' when [prep()] is run, some operations may not be able to be
37+
#' conducted on new data (e.g. processing the outcome variable(s)).
38+
#' Care should be taken when using `skip = TRUE` as it may affect
39+
#' the computations for subsequent operations.
40+
#' @param id A unique identifier for the step
41+
#' @template step-return
42+
#'
43+
#' @details The step assumes that the pipeline has already applied either
44+
#' `step_epi_ahead` or `step_epi_lag` depending on the value of
45+
#' `"method"`, and that `step_epi_naomit` has NOT been run.
46+
#'
47+
#' The `prefix` and `id` arguments are unchangeable to ensure that the code runs
48+
#' properly and to avoid inconsistency with naming. For `step_epi_ahead`, they
49+
#' are always set to `"ahead_"` and `"epi_ahead"` respectively, while for
50+
#' `step_epi_lag`, they are set to `"lag_"` and `"epi_lag`, respectively.
51+
#'
52+
#' @family row operation steps
53+
#' @rdname step_adjust_latency
54+
#' @export
55+
#' @examples
56+
#' r <- epi_recipe(case_death_rate_subset) %>%
57+
#' step_epi_ahead(death_rate, ahead = 7) %>%
58+
#' # step_adjust_latency(method = "extend_ahead") %>%
59+
#' step_epi_lag(death_rate, lag = c(0, 7, 14))
60+
#' r
61+
step_adjust_latency <-
62+
function(recipe,
63+
...,
64+
role = NA,
65+
trained = FALSE,
66+
method = c(
67+
"extend_ahead",
68+
"locf",
69+
"extend_lags"
70+
),
71+
default = NA,
72+
skip = FALSE,
73+
prefix = NULL,
74+
columns = NULL,
75+
id = recipes::rand_id("epi_lag")) {
76+
if (!is_epi_recipe(recipe)) {
77+
rlang::abort("This recipe step can only operate on an `epi_recipe`.")
78+
}
79+
if (!is.null(columns)) {
80+
rlang::abort(c("The `columns` argument must be `NULL.",
81+
i = "Use `tidyselect` methods to choose columns to lag."
82+
))
83+
}
84+
85+
method <- rlang::arg_match(method)
86+
if (method == "extend_ahead") {
87+
prefix <- "ahead_"
88+
} else {
89+
prefix <- "lag_"
90+
}
91+
92+
arg_is_chr_scalar(prefix, id, method)
93+
recipes::add_step(
94+
recipe,
95+
step_adjust_latency_new(
96+
terms = dplyr::enquos(...),
97+
role = role,
98+
method = method,
99+
info = NULL,
100+
trained = trained,
101+
prefix = prefix,
102+
default = default,
103+
keys = epi_keys(recipe),
104+
columns = columns,
105+
skip = skip,
106+
id = id
107+
)
108+
)
109+
}
110+
111+
step_adjust_latency_new <-
112+
function(terms, role, trained, prefix, default, keys, method, info,
113+
columns, skip, id) {
114+
step(
115+
subclass = "adjust_latency",
116+
terms = terms,
117+
role = role,
118+
method = method,
119+
info = info,
120+
trained = trained,
121+
prefix = prefix,
122+
default = default,
123+
keys = keys,
124+
columns = columns,
125+
skip = skip,
126+
id = id
127+
)
128+
}
129+
130+
#' @export
131+
prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
132+
if ((x$method == "extend_ahead") && (!("outcome" %in% info$role))) {
133+
cli::cli_abort(glue::glue(c('If `method` is `"extend_ahead"`, then a step ",
134+
"must have already added an outcome .')))
135+
} else if (!("predictor" %in% info$role)) {
136+
cli::cli_abort('If `method` is `"extend_lags"` or `"locf"`, then a step ",
137+
"must have already added a predictor.')
138+
}
139+
# TODO info here is probably not the best way to handle this, hypothetically I
140+
# get an info object during baking
141+
step_adjust_latency_new(
142+
terms = x$terms,
143+
role = x$role,
144+
trained = TRUE,
145+
prefix = x$prefix,
146+
default = x$default,
147+
keys = x$keys,
148+
method = x$method,
149+
info = info,
150+
columns = recipes::recipes_eval_select(x$terms, training, info),
151+
skip = x$skip,
152+
id = x$id
153+
)
154+
}
155+
156+
#' various ways of handling differences between the `as_of` date and the maximum
157+
#' time value
158+
#' @description
159+
#' adjust the ahead so that we will be predicting `ahead` days after the `as_of`
160+
#' date, rather than relative to the last day of data
161+
#' @param new_data assumes that this already has lag/ahead columns that we need
162+
#' to adjust
163+
#' @importFrom dplyr %>%
164+
#' @keywords internal
165+
#' @importFrom dplyr %>% pull
166+
bake.step_adjust_latency <- function(object, new_data, ...) {
167+
sign_shift <- get_sign(object)
168+
# get the columns used, even if it's all of them
169+
terms_used <- object$columns
170+
if (length(terms_used) == 0) {
171+
terms_used <- object$info %>%
172+
filter(role == "raw") %>%
173+
pull(variable)
174+
}
175+
# get and check the max_time and as_of are the right kinds of dates
176+
as_of <- get_asof(object, new_data)
177+
178+
# infer the correct columns to be working with from the previous
179+
# transformations
180+
shift_cols <- get_shifted_column_tibble(object, new_data, terms_used, as_of,
181+
sign_shift)
182+
183+
if ((object$method == "extend_ahead") || (object$method == "extend_lags")) {
184+
# check that the shift amount isn't too extreme
185+
latency <- max(shift_cols$latency)
186+
i_latency <- which.max(shift_cols$latency)
187+
time_type <- attributes(new_data)$metadata$time_type
188+
if (
189+
(grepl("day", time_type) && (latency >= 10)) ||
190+
(grepl("week", time_type) && (latency >= 4)) ||
191+
((time_type == "yearmonth") && (latency >= 2)) ||
192+
((time_type == "yearquarter") && (latency >= 1)) ||
193+
((time_type == "year") && (latency >= 1))
194+
) {
195+
cli::cli_warn(c(
196+
"!" = glue::glue(
197+
"The shift has been adjusted by {latency}, ",
198+
"which is questionable for it's `time_type` of ",
199+
"{time_type}"
200+
),
201+
"i" = "input ahead: {shift_cols$shifts[[i_latency]]}",
202+
"i" = "shifted ahead: {shift_cols$effective_shift[[i_latency]]}",
203+
"i" = "max_time = {max_time} -> as_of = {as_of}"
204+
))
205+
}
206+
keys <- object$keys
207+
return(
208+
extend_either(new_data, shift_cols, keys)
209+
)
210+
}
211+
}

R/step_epi_shift.R

+2-65
Original file line numberDiff line numberDiff line change
@@ -246,76 +246,13 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) {
246246

247247
#' @export
248248
bake.step_epi_lag <- function(object, new_data, ...) {
249-
grid <- tidyr::expand_grid(col = object$columns, lag = object$lag) %>%
250-
dplyr::mutate(
251-
newname = glue::glue("{object$prefix}{lag}_{col}"),
252-
shift_val = lag,
253-
lag = NULL
254-
)
255-
256-
## ensure no name clashes
257-
new_data_names <- colnames(new_data)
258-
intersection <- new_data_names %in% grid$newname
259-
if (any(intersection)) {
260-
rlang::abort(
261-
paste0(
262-
"Name collision occured in `", class(object)[1],
263-
"`. The following variable names already exists: ",
264-
paste0(new_data_names[intersection], collapse = ", "),
265-
"."
266-
)
267-
)
268-
}
269-
ok <- object$keys
270-
shifted <- reduce(
271-
pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
272-
dplyr::full_join,
273-
by = ok
274-
)
275-
276-
dplyr::full_join(new_data, shifted, by = ok) %>%
277-
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
278-
dplyr::arrange(time_value) %>%
279-
dplyr::ungroup()
249+
add_shifted_columns(new_data, object, object$lag)
280250
}
281-
282251
#' @export
283252
bake.step_epi_ahead <- function(object, new_data, ...) {
284-
ahead <- adjust_latency(object, new_data)
285-
grid <- tidyr::expand_grid(col = object$columns, ahead = ahead) %>%
286-
dplyr::mutate(
287-
newname = glue::glue("{object$prefix}{ahead}_{col}"),
288-
shift_val = -ahead,
289-
ahead = NULL
290-
)
291-
292-
## ensure no name clashes
293-
new_data_names <- colnames(new_data)
294-
intersection <- new_data_names %in% grid$newname
295-
if (any(intersection)) {
296-
rlang::abort(
297-
paste0(
298-
"Name collision occured in `", class(object)[1],
299-
"`. The following variable names already exists: ",
300-
paste0(new_data_names[intersection], collapse = ", "),
301-
"."
302-
)
303-
)
304-
}
305-
ok <- object$keys
306-
shifted <- reduce(
307-
pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
308-
dplyr::full_join,
309-
by = ok
310-
)
311-
312-
dplyr::full_join(new_data, shifted, by = ok) %>%
313-
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
314-
dplyr::arrange(time_value) %>%
315-
dplyr::ungroup()
253+
add_shifted_columns(new_data, object, object$ahead)
316254
}
317255

318-
319256
#' @export
320257
print.step_epi_lag <- function(x, width = max(20, options()$width - 30), ...) {
321258
print_epi_step(x$columns, x$terms, x$trained, "Lagging",

0 commit comments

Comments
 (0)