Skip to content

Commit a415100

Browse files
authored
Merge pull request #408 from cmu-delphi/prodFixes
Prod fixes
2 parents 4b9fc72 + fb7d6ba commit a415100

17 files changed

+162
-45
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.1.1
3+
Version: 0.1.2
44
Authors@R: c(
55
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
66

77
## features
88
- Add `step_adjust_latency`, which give several methods to adjust the forecast if the `forecast_date` is after the last day of data.
9+
- (temporary) ahead negative is allowed for `step_epi_ahead` until we have `step_epi_shift`
910

1011
## bugfixes
12+
- shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag`
1113

1214
# epipredict 0.1
1315

R/epi_shift.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ get_sign <- function(object) {
4242
add_shifted_columns <- function(new_data, object) {
4343
grid <- object$shift_grid
4444

45+
if (nrow(object$shift_grid) == 0) {
46+
# we're not shifting any rows, so this is a no-op
47+
return(new_data)
48+
}
4549
## ensure no name clashes
4650
new_data_names <- colnames(new_data)
4751
intersection <- new_data_names %in% grid$newname

R/epi_workflow.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
167167
components$forged <- hardhat::forge(new_data,
168168
blueprint = components$mold$blueprint
169169
)
170+
170171
components$keys <- grab_forged_keys(components$forged, object, new_data)
171172
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
172173
components$predictions

R/step_adjust_latency.R

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@
152152
#'
153153
#' Note that this is a separate concern from different latencies across
154154
#' different *data columns*, which is only handled by the choice of `method`.
155+
#' @param keys_to_ignore a list of character vectors. Set this to avoid using
156+
#' specific key values in the `epi_keys_checked` to set latency. For example,
157+
#' say you have two locations `pr` and `gu` which have useful training data,
158+
#' but have stopped providing up-to-date information, and so are no longer
159+
#' part of the test set. Setting `keys_to_ignore = list(geo_value = c("pr",
160+
#' "gu"))` will exclude them from the latency calculation.
155161
#' @param fixed_latency either a positive integer, or a labeled positive integer
156162
#' vector. Cannot be set at the same time as `fixed_forecast_date`. If
157163
#' non-`NULL`, the amount to offset the ahead or lag by. If a single integer,
@@ -203,6 +209,7 @@ step_adjust_latency <-
203209
"extend_lags"
204210
),
205211
epi_keys_checked = NULL,
212+
keys_to_ignore = c(),
206213
fixed_latency = NULL,
207214
fixed_forecast_date = NULL,
208215
check_latency_length = TRUE,
@@ -228,6 +235,7 @@ step_adjust_latency <-
228235
metadata = NULL,
229236
method = method,
230237
epi_keys_checked = epi_keys_checked,
238+
keys_to_ignore = keys_to_ignore,
231239
check_latency_length = check_latency_length,
232240
columns = NULL,
233241
skip = FALSE,
@@ -239,7 +247,7 @@ step_adjust_latency <-
239247
step_adjust_latency_new <-
240248
function(terms, role, trained, fixed_forecast_date, forecast_date, latency,
241249
latency_table, latency_sign, metadata, method, epi_keys_checked,
242-
check_latency_length, columns, skip, id) {
250+
keys_to_ignore, check_latency_length, columns, skip, id) {
243251
step(
244252
subclass = "adjust_latency",
245253
terms = terms,
@@ -253,6 +261,7 @@ step_adjust_latency_new <-
253261
metadata = metadata,
254262
method = method,
255263
epi_keys_checked = epi_keys_checked,
264+
keys_to_ignore = keys_to_ignore,
256265
check_latency_length = check_latency_length,
257266
columns = columns,
258267
skip = skip,
@@ -271,7 +280,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
271280

272281
latency_table <- get_latency_table(
273282
training, NULL, forecast_date, latency,
274-
get_sign(x), x$epi_keys_checked, info, x$terms
283+
get_sign(x), x$epi_keys_checked, x$keys_to_ignore, info, x$terms
275284
)
276285
# get the columns used, even if it's all of them
277286
terms_used <- x$terms
@@ -293,6 +302,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
293302
metadata = attributes(training)$metadata,
294303
method = x$method,
295304
epi_keys_checked = x$epi_keys_checked,
305+
keys_to_ignore = x$keys_to_ignore,
296306
check_latency_length = x$check_latency_length,
297307
columns = recipes_eval_select(latency_table$col_name, training, info),
298308
skip = x$skip,
@@ -305,10 +315,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
305315
#' @export
306316
bake.step_adjust_latency <- function(object, new_data, ...) {
307317
if (!inherits(new_data, "epi_df") || is.null(attributes(new_data)$metadata$as_of)) {
308-
new_data <- as_epi_df(new_data)
318+
new_data <- as_epi_df(new_data, as_of = object$forecast_date, other_keys = object$metadata$other_keys %||% character())
309319
attributes(new_data)$metadata <- object$metadata
310-
attributes(new_data)$metadata$as_of <- object$forecast_date
311-
} else {
312320
compare_bake_prep_latencies(object, new_data)
313321
}
314322
if (object$method == "locf") {

R/step_epi_shift.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ step_epi_ahead <-
111111
i = "Did you perhaps pass an integer in `...` accidentally?"
112112
))
113113
}
114-
arg_is_nonneg_int(ahead)
115114
arg_is_chr_scalar(prefix, id)
116115

117116
recipes::add_step(

R/utils-latency.R

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,18 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
5050
)
5151
}
5252
}
53+
max_time <- get_max_time(new_data, epi_keys_checked, columns)
5354
# the source data determines the actual time_values
54-
# these are the non-na time_values;
55-
# get the minimum value across the checked epi_keys' maximum time values
56-
max_time <- new_data %>%
57-
select(all_of(columns)) %>%
58-
drop_na()
59-
# null and "" don't work in `group_by`
60-
if (!is.null(epi_keys_checked) && (epi_keys_checked != "")) {
61-
max_time <- max_time %>% group_by(get(epi_keys_checked))
62-
}
63-
max_time <- max_time %>%
64-
summarise(time_value = max(time_value)) %>%
65-
pull(time_value) %>%
66-
min()
6755
if (is.null(latency)) {
6856
forecast_date <- attributes(new_data)$metadata$as_of
6957
} else {
58+
if (is.null(max_time)) {
59+
cli_abort("max_time is null. This likely means there is one of {columns} that is all `NA`")
60+
}
7061
forecast_date <- max_time + latency
7162
}
7263
# make sure the as_of is sane
73-
if (!inherits(forecast_date, class(max_time)) & !inherits(forecast_date, "POSIXt")) {
64+
if (!inherits(forecast_date, class(new_data$time_value)) & !inherits(forecast_date, "POSIXt")) {
7465
cli_abort(
7566
paste(
7667
"the data matrix `forecast_date` value is {forecast_date}, ",
@@ -84,13 +75,13 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
8475
if (is.null(forecast_date) || is.na(forecast_date)) {
8576
cli_warn(
8677
paste(
87-
"epi_data's `forecast_date` was {forecast_date}, setting to ",
88-
"the latest time value, {max_time}."
78+
"epi_data's `forecast_date` was `NA`, setting to ",
79+
"the latest non-`NA` time value for these columns, {max_time}."
8980
),
9081
class = "epipredict__get_forecast_date__max_time_warning"
9182
)
9283
forecast_date <- max_time
93-
} else if (forecast_date < max_time) {
84+
} else if (!is.null(max_time) && (forecast_date < max_time)) {
9485
cli_abort(
9586
paste(
9687
"`forecast_date` ({(forecast_date)}) is before the most ",
@@ -101,22 +92,49 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
10192
)
10293
}
10394
# TODO cover the rest of the possible types for as_of and max_time...
104-
if (inherits(max_time, "Date")) {
95+
if (inherits(new_data$time_value, "Date")) {
10596
forecast_date <- as.Date(forecast_date)
10697
}
10798
return(forecast_date)
10899
}
109100

101+
get_max_time <- function(new_data, epi_keys_checked, columns) {
102+
# these are the non-na time_values;
103+
# get the minimum value across the checked epi_keys' maximum time values
104+
max_time <- new_data %>%
105+
select(all_of(columns)) %>%
106+
drop_na()
107+
if (nrow(max_time) == 0) {
108+
return(NULL)
109+
}
110+
# null and "" don't work in `group_by`
111+
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
112+
max_time <- max_time %>% group_by(across(all_of(epi_keys_checked)))
113+
}
114+
max_time <- max_time %>%
115+
summarise(time_value = max(time_value)) %>%
116+
pull(time_value) %>%
117+
min()
118+
return(max_time)
119+
}
120+
121+
122+
110123
#' the latency is also the amount the shift is off by
111124
#' @param sign_shift integer. 1 if lag and -1 if ahead. These represent how you
112125
#' need to shift the data to bring the 3 day lagged value to today.
113126
#' @keywords internal
114127
get_latency <- function(new_data, forecast_date, column, sign_shift, epi_keys_checked) {
115128
shift_max_date <- new_data %>%
116129
drop_na(all_of(column))
130+
if (nrow(shift_max_date) == 0) {
131+
# if everything is an NA, there's infinite latency, but shifting by that is
132+
# untenable. May as well not shift at all
133+
return(0)
134+
}
117135
# null and "" don't work in `group_by`
118-
if (!is.null(epi_keys_checked) && epi_keys_checked != "") {
119-
shift_max_date <- shift_max_date %>% group_by(get(epi_keys_checked))
136+
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
137+
shift_max_date <- shift_max_date %>% group_by(across(all_of(epi_keys_checked)))
120138
}
121139
shift_max_date <- shift_max_date %>%
122140
summarise(time_value = max(time_value)) %>%
@@ -290,7 +308,8 @@ check_interminable_latency <- function(dataset, latency_table, target_columns, f
290308
#' @keywords internal
291309
#' @importFrom dplyr rowwise
292310
get_latency_table <- function(training, columns, forecast_date, latency,
293-
sign_shift, epi_keys_checked, info, terms) {
311+
sign_shift, epi_keys_checked, keys_to_ignore,
312+
info, terms) {
294313
if (is.null(columns)) {
295314
columns <- recipes_eval_select(terms, training, info)
296315
}
@@ -300,12 +319,17 @@ get_latency_table <- function(training, columns, forecast_date, latency,
300319
if (length(columns) > 0) {
301320
latency_table <- latency_table %>% filter(col_name %in% columns)
302321
}
303-
322+
training_dropped <- training %>%
323+
drop_ignored_keys(keys_to_ignore)
304324
if (is.null(latency)) {
305325
latency_table <- latency_table %>%
306326
rowwise() %>%
307327
mutate(latency = get_latency(
308-
training, forecast_date, col_name, sign_shift, epi_keys_checked
328+
training_dropped,
329+
forecast_date,
330+
col_name,
331+
sign_shift,
332+
epi_keys_checked
309333
))
310334
} else if (length(latency) > 1) {
311335
# if latency has a length, it must also have named elements.
@@ -319,7 +343,7 @@ get_latency_table <- function(training, columns, forecast_date, latency,
319343
latency_table <- latency_table %>%
320344
rowwise() %>%
321345
mutate(latency = get_latency(
322-
training, forecast_date, col_name, sign_shift, epi_keys_checked
346+
training %>% drop_ignored_keys(keys_to_ignore), forecast_date, col_name, sign_shift, epi_keys_checked
323347
))
324348
if (latency) {
325349
latency_table <- latency_table %>% mutate(latency = latency)
@@ -328,6 +352,19 @@ get_latency_table <- function(training, columns, forecast_date, latency,
328352
return(latency_table %>% ungroup())
329353
}
330354

355+
#' given a list named by key columns, remove any matching key values
356+
#' keys_to_ignore should have the form list(col_name = c("value_to_ignore", "other_value_to_ignore"))
357+
#' @keywords internal
358+
drop_ignored_keys <- function(training, keys_to_ignore) {
359+
# note that the extra parenthesis black magic is described here: https://github.com/tidyverse/dplyr/issues/6194
360+
# and is needed to bypass an incomplete port of `across` functions to `if_any`
361+
training %>%
362+
filter((dplyr::if_all(
363+
names(keys_to_ignore),
364+
~ . %nin% keys_to_ignore[[cur_column()]]
365+
)))
366+
}
367+
331368

332369
#' checks: the recipe type, whether a previous step is the relevant epi_shift,
333370
#' that either `fixed_latency` or `fixed_forecast_date` is non-null, and that
@@ -394,7 +431,7 @@ compare_bake_prep_latencies <- function(object, new_data, call = caller_env()) {
394431
)
395432
local_latency_table <- get_latency_table(
396433
new_data, object$columns, current_forecast_date, latency,
397-
get_sign(object), object$epi_keys_checked, NULL, NULL
434+
get_sign(object), object$epi_keys_checked, object$keys_to_ignore, NULL, NULL
398435
)
399436
comparison_table <- local_latency_table %>%
400437
ungroup() %>%

man/drop_ignored_keys.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_latency_table.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/step_adjust_latency.Rd

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)