Skip to content

Commit 9875d2c

Browse files
committed
Make shifting with an empty shift grid a no-op; make max_time robust to all-NA columns
1 parent c99b9c9 commit 9875d2c

File tree

7 files changed

+47
-44
lines changed

7 files changed

+47
-44
lines changed

R/epi_shift.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ get_sign <- function(object) {
4242
add_shifted_columns <- function(new_data, object) {
4343
grid <- object$shift_grid
4444

45+
if (nrow(object$shift_grid) == 0) {
46+
# the shift grid is empty, so there is nothing to shift: this is a no-op
47+
return(new_data)
48+
}
4549
## ensure no name clashes
4650
new_data_names <- colnames(new_data)
4751
intersection <- new_data_names %in% grid$newname

R/step_adjust_latency.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,6 @@ bake.step_adjust_latency <- function(object, new_data, ...) {
317317
if (!inherits(new_data, "epi_df") || is.null(attributes(new_data)$metadata$as_of)) {
318318
new_data <- as_epi_df(new_data, as_of = object$forecast_date, other_keys = object$metadata$other_keys %||% character())
319319
attributes(new_data)$metadata <- object$metadata
320-
attributes(new_data)$metadata$as_of <- object$forecast_date
321-
} else {
322320
compare_bake_prep_latencies(object, new_data)
323321
}
324322
if (object$method == "locf") {

R/step_epi_shift.R

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,6 @@ prep.step_epi_lag <- function(x, training, info = NULL, ...) {
192192
} else {
193193
shift_grid <- x$shift_grid
194194
}
195-
if (nrow(shift_grid) == 0) {
196-
cli_warn(
197-
c(
198-
"prepping no columns!",
199-
"{x$terms} returns no columns for this dataset."
200-
),
201-
class = "epipredict__step_epi_lag__no_columns_shifted"
202-
)
203-
}
204195

205196
step_epi_lag_new(
206197
terms = x$terms,
@@ -235,15 +226,6 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) {
235226
} else {
236227
shift_grid <- x$shift_grid
237228
}
238-
if (nrow(shift_grid) == 0) {
239-
cli_warn(
240-
c(
241-
"prepping no columns!",
242-
"{x$terms} returns no columns for this dataset."
243-
),
244-
class = "epipredict__step_epi_ahead__no_columns_shifted"
245-
)
246-
}
247229

248230
step_epi_ahead_new(
249231
terms = x$terms,

R/utils-latency.R

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,18 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
5050
)
5151
}
5252
}
53+
max_time <- get_max_time(new_data, epi_keys_checked, columns)
5354
# the source data determines the actual time_values
54-
# these are the non-na time_values;
55-
# get the minimum value across the checked epi_keys' maximum time values
56-
max_time <- new_data %>%
57-
select(all_of(columns)) %>%
58-
drop_na()
59-
# null and "" don't work in `group_by`
60-
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
61-
max_time <- max_time %>% group_by(across(all_of(epi_keys_checked)))
62-
}
63-
max_time <- max_time %>%
64-
summarise(time_value = max(time_value)) %>%
65-
pull(time_value) %>%
66-
min()
6755
if (is.null(latency)) {
6856
forecast_date <- attributes(new_data)$metadata$as_of
6957
} else {
58+
if (is.null(max_time)) {
59+
cli_abort("max_time is null. This likely means there is one of {columns} that is all `NA`")
60+
}
7061
forecast_date <- max_time + latency
7162
}
7263
# make sure the as_of is sane
73-
if (!inherits(forecast_date, class(max_time)) & !inherits(forecast_date, "POSIXt")) {
64+
if (!inherits(forecast_date, class(new_data$time_value)) & !inherits(forecast_date, "POSIXt")) {
7465
cli_abort(
7566
paste(
7667
"the data matrix `forecast_date` value is {forecast_date}, ",
@@ -84,13 +75,13 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
8475
if (is.null(forecast_date) || is.na(forecast_date)) {
8576
cli_warn(
8677
paste(
87-
"epi_data's `forecast_date` was {forecast_date}, setting to ",
88-
"the latest time value, {max_time}."
78+
"epi_data's `forecast_date` was `NA`, setting to ",
79+
"the latest non-`NA` time value for these columns, {max_time}."
8980
),
9081
class = "epipredict__get_forecast_date__max_time_warning"
9182
)
9283
forecast_date <- max_time
93-
} else if (forecast_date < max_time) {
84+
} else if (!is.null(max_time) && (forecast_date < max_time)) {
9485
cli_abort(
9586
paste(
9687
"`forecast_date` ({(forecast_date)}) is before the most ",
@@ -101,19 +92,46 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
10192
)
10293
}
10394
# TODO cover the rest of the possible types for as_of and max_time...
104-
if (inherits(max_time, "Date")) {
95+
if (inherits(new_data$time_value, "Date")) {
10596
forecast_date <- as.Date(forecast_date)
10697
}
10798
return(forecast_date)
10899
}
109100

101+
get_max_time <- function(new_data, epi_keys_checked, columns) {
102+
# these are the non-na time_values;
103+
# get the minimum value across the checked epi_keys' maximum time values
104+
max_time <- new_data %>%
105+
select(all_of(columns)) %>%
106+
drop_na()
107+
if (nrow(max_time) == 0) {
108+
return(NULL)
109+
}
110+
# null and "" don't work in `group_by`
111+
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
112+
max_time <- max_time %>% group_by(across(all_of(epi_keys_checked)))
113+
}
114+
max_time <- max_time %>%
115+
summarise(time_value = max(time_value)) %>%
116+
pull(time_value) %>%
117+
min()
118+
return(max_time)
119+
}
120+
121+
122+
110123
#' the latency is also the amount the shift is off by
111124
#' @param sign_shift integer. 1 if lag and -1 if ahead. These represent how you
112125
#' need to shift the data to bring the 3 day lagged value to today.
113126
#' @keywords internal
114127
get_latency <- function(new_data, forecast_date, column, sign_shift, epi_keys_checked) {
115128
shift_max_date <- new_data %>%
116129
drop_na(all_of(column))
130+
if (nrow(shift_max_date) == 0) {
131+
# if everything is an NA, there's infinite latency, but shifting by that is
132+
# untenable. May as well not shift at all
133+
return(0)
134+
}
117135
# null and "" don't work in `group_by`
118136
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
119137
shift_max_date <- shift_max_date %>% group_by(across(all_of(epi_keys_checked)))

man/step_adjust_latency.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-step_adjust_latency.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,6 @@ test_that("epi_adjust_latency correctly extends the lags when there are differen
398398
names(fit5$pre$mold$outcomes),
399399
glue::glue("ahead_{ahead}_death_rate")
400400
)
401-
latest <- get_test_data(r5, x)
402-
pred <- predict(fit5, latest)
403-
actual_solutions <- pred %>% filter(!is.na(.pred))
404-
expect_equal(actual_solutions$time_value, testing_as_of + 1)
405401

406402
# should have six coefficients, including the intercept
407403
expect_equal(length(fit5$fit$fit$fit$coefficients), 6)

tests/testthat/test-step_epi_shift.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,8 @@ test_that("Check that epi_lag shifts applies the shift", {
6666
# Should have four predictors, including the intercept
6767
expect_equal(length(fit5$fit$fit$fit$coefficients), 4)
6868
})
69+
70+
test_that("Shifting nothing is a no-op", {
71+
expect_no_error(noop <- epi_recipe(x) %>% step_epi_ahead(ahead = 3) %>% prep(x) %>% bake(x))
72+
expect_equal(noop, x)
73+
})

0 commit comments

Comments
 (0)