Skip to content

Commit 3d7fbe0

Browse files
committed
include climate, only calculate necessary days
1 parent e8bc12a commit 3d7fbe0

File tree

6 files changed

+156
-74
lines changed

6 files changed

+156
-74
lines changed

R/autoplot.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,12 @@ autoplot.epi_workflow <- function(
113113
keys <- c("geo_value", "time_value", "key")
114114
mold_roles <- names(mold$extras$roles)
115115
# extract the relevant column names for plotting
116-
old_name_y <- unlist(strsplit(names(y), "_"))
117-
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
116+
if (starts_with_impl("ahead_", names(y)) || starts_with_impl("lag_", names(y))) {
117+
old_name_y <- unlist(strsplit(names(y), "_"))
118+
new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_")
119+
} else {
120+
new_name_y <- names(y)
121+
}
118122
if (is.null(plot_data)) {
119123
# the outcome has shifted, so we need to shift it forward (or back)
120124
# by the corresponding amount

R/climatological_forecaster.R

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,42 @@ climatological_forecaster <- function(epi_data,
115115
mean = function(x, w) mean(x, na.rm = TRUE),
116116
median = function(x, w) stats::median(x, na.rm = TRUE)
117117
)
118-
# get the point predictions
119118
keys <- key_colnames(epi_data, exclude = "time_value")
120-
epi_data <- epi_data %>% mutate(.idx = time_aggr(time_value), .weights = 1)
121-
climate_center <- epi_data %>%
119+
# Get the prediction geo and .idx for the target date(s)
120+
predictions <- epi_data %>%
121+
select(all_of(keys)) %>%
122+
dplyr::distinct() %>%
123+
mutate(forecast_date = forecast_date, .idx = time_aggr(forecast_date))
124+
predictions <-
125+
map(horizon, ~ {
126+
predictions %>%
127+
mutate(.idx = .idx + .x, target_date = forecast_date + ttype_dur(.x))
128+
}) %>%
129+
purrr::list_rbind() %>%
130+
mutate(
131+
.idx = .idx %% modulus,
132+
.idx = dplyr::case_when(.idx == 0 ~ modulus, TRUE ~ .idx)
133+
)
134+
# get the distinct .idx for the target date(s)
135+
distinct_target_idx <- predictions$.idx %>% unique()
136+
# get all of the idx's within the window of the target .idxs
137+
entries <- map(distinct_target_idx, \(idx) within_window(idx, window_size, modulus)) %>%
138+
do.call(c, .) %>%
139+
unique()
140+
# for the center, we need those within twice the window, since for each point
141+
# we're subtracting out the center to generate the quantiles
142+
entries_double_window <- map(entries, \(idx) within_window(idx, window_size, modulus)) %>%
143+
do.call(c, .) %>%
144+
unique()
145+
146+
epi_data_target <-
147+
epi_data %>%
148+
mutate(.idx = time_aggr(time_value), .weights = 1)
149+
# get the point predictions
150+
climate_center <-
151+
epi_data_target %>%
152+
filter(.idx %in% entries_double_window) %>%
153+
mutate(.idx = time_aggr(time_value), .weights = 1) %>%
122154
select(.idx, .weights, all_of(c(outcome, keys))) %>%
123155
dplyr::reframe(
124156
roll_modular_multivec(
@@ -136,7 +168,10 @@ climatological_forecaster <- function(epi_data,
136168
probs = args_list$quantile_levels, na.rm = TRUE, type = 8
137169
)))
138170
}
139-
climate_quantiles <- epi_data %>%
171+
# add on the centers and subtract them out before computing the quantiles
172+
climate_quantiles <-
173+
epi_data_target %>%
174+
filter(.idx %in% entries) %>%
140175
left_join(climate_center, by = c(".idx", keys)) %>%
141176
mutate({{ outcome }} := !!sym_outcome - .pred) %>%
142177
select(.idx, .weights, all_of(c(outcome, args_list$quantile_by_key))) %>%
@@ -147,31 +182,17 @@ climatological_forecaster <- function(epi_data,
147182
),
148183
.by = all_of(args_list$quantile_by_key)
149184
) %>%
150-
rename(.pred_distn = climate_pred) %>%
151-
mutate(.pred_distn = hardhat::quantile_pred(do.call(rbind, .pred_distn), args_list$quantile_levels))
185+
mutate(.pred_distn = hardhat::quantile_pred(do.call(rbind, climate_pred), args_list$quantile_levels)) %>%
186+
select(-climate_pred)
152187
# combine them together
153188
climate_table <- climate_center %>%
154-
left_join(climate_quantiles, by = c(".idx", args_list$quantile_by_key)) %>%
189+
inner_join(climate_quantiles, by = c(".idx", args_list$quantile_by_key)) %>%
155190
mutate(.pred_distn = .pred_distn + .pred)
156-
# create the predictions
157-
predictions <- epi_data %>%
158-
select(all_of(keys)) %>%
159-
dplyr::distinct() %>%
160-
mutate(forecast_date = forecast_date, .idx = time_aggr(forecast_date))
161-
predictions <- map(horizon, ~ {
162-
predictions %>%
163-
mutate(.idx = .idx + .x, target_date = forecast_date + ttype_dur(.x))
164-
}) %>%
165-
purrr::list_rbind() %>%
166-
mutate(
167-
.idx = .idx %% modulus,
168-
.idx = dplyr::case_when(.idx == 0 ~ modulus, TRUE ~ .idx)
169-
) %>%
191+
predictions <- predictions %>%
170192
left_join(climate_table, by = c(".idx", keys)) %>%
171193
select(-.idx)
172194
if (args_list$nonneg) {
173-
predictions <- mutate(
174-
predictions,
195+
predictions <- predictions %>% mutate(
175196
.pred = snap(.pred, 0, Inf),
176197
.pred_distn = snap(.pred_distn, 0, Inf)
177198
)

R/step_climate.R

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -338,57 +338,74 @@ print.step_climate <- function(x, width = max(20, options()$width - 30), ...) {
338338
}
339339

340340
#' group col by .idx values and aggregate windows around each .idx value
341-
#' @param .idx the relevant periodic part of time value, e.g. the week number
342-
#' @param col the list of values indexed by `.idx`
343-
#' @param weights how much to weigh each particular datapoint
344-
#' @param aggr the aggregation function, probably Quantile, mean or median
341+
#' @param idx_in the relevant periodic part of time value, e.g. the week number,
342+
#' limited to the relevant range
343+
#' @param col the list of values indexed by `idx_in`
344+
#' @param weights how much to weigh each particular datapoint (also indexed by
345+
#' `idx_in`)
346+
#' @param aggr the aggregation function, probably Quantile, mean, or median
345347
#' @param window_size the number of .idx entries before and after to include in
346348
#' the aggregation
347-
#' @param modulus the maximum value of `.idx`
349+
#' @param modulus the number of days/weeks/months in the year, not including any
350+
#' leap days/weeks
348351
#' @importFrom lubridate %m-%
349352
#' @keywords internal
350-
roll_modular_multivec <- function(col, .idx, weights, aggr, window_size, modulus) {
351-
tib <- tibble(col = col, weights = weights, .idx = .idx) |>
353+
roll_modular_multivec <- function(col, idx_in, weights, aggr, window_size, modulus) {
354+
# make a tibble where data gives the list of all datapoints with the
355+
# corresponding .idx
356+
tib <- tibble(col = col, weights = weights, .idx = idx_in) |>
352357
arrange(.idx) |>
353358
tidyr::nest(data = c(col, weights), .by = .idx)
354-
out <- double(modulus + 1)
355-
for (iter in seq_along(out)) {
356-
# +1 from 1-indexing
357-
entries <- (iter - window_size):(iter + window_size) %% modulus
358-
entries[entries == 0] <- modulus
359-
# note that because we are 1-indexing, we're looking for indices that are 1
360-
# larger than the actual day/week in the year
361-
if (modulus == 365) {
362-
# we need to grab just the window around the leap day on the leap day
363-
if (iter == 366) {
364-
# there's an extra data point in front of the leap day
365-
entries <- (59 - window_size):(59 + window_size - 1) %% modulus
366-
entries[entries == 0] <- modulus
367-
# adding in the leap day itself
368-
entries <- c(entries, 999)
369-
} else if ((59 %in% entries) || (60 %in% entries)) {
370-
# if we're on the Feb/March boundary for daily data, we need to add in the
371-
# leap day data
372-
entries <- c(entries, 999)
373-
}
374-
} else if (modulus == 52) {
375-
# we need to grab just the window around the leap week on the leap week
376-
if (iter == 53) {
377-
entries <- (53 - window_size):(53 + window_size - 1) %% 52
378-
entries[entries == 0] <- 52
379-
entries <- c(entries, 999)
380-
} else if ((52 %in% entries) || (1 %in% entries)) {
381-
# if we're on the year boundary for weekly data, we need to add in the
382-
# leap week data (which is the extra week at the end)
383-
entries <- c(entries, 999)
384-
}
385-
}
386-
out[iter] <- with(
359+
# storage for the results, includes all possible time indexes
360+
out <- tibble(.idx = c(1:modulus, 999), climate_pred = double(modulus + 1))
361+
for (tib_idx in tib$.idx) {
362+
entries <- within_window(tib_idx, window_size, modulus)
363+
out$climate_pred[out$.idx == tib_idx] <- with(
387364
purrr::list_rbind(tib %>% filter(.idx %in% entries) %>% pull(data)),
388365
aggr(col, weights)
389366
)
390367
}
391-
tibble(.idx = unique(tib$.idx), climate_pred = out[seq_len(nrow(tib))])
368+
# filter to only the ones we actually computed
369+
out %>% filter(.idx %in% idx_in)
370+
}
371+
372+
#' generate the idx values within `window_size` of `target_idx` given that our
373+
#' time value is of the type matching modulus
374+
#' @param target_idx the time index which we're drawing the window around
375+
#' @param window_size the size of the window on one side of `target_idx`
376+
#' @param modulus the number of days/weeks/months in the year, not including any leap days/weeks
377+
#' @keywords internal
378+
within_window <- function(target_idx, window_size, modulus) {
379+
entries <- (target_idx - window_size):(target_idx + window_size) %% modulus
380+
entries[entries == 0] <- modulus
381+
# note that because we are 1-indexing, we're looking for indices that are 1
382+
# larger than the actual day/week in the year
383+
if (modulus == 365) {
384+
# we need to grab just the window around the leap day on the leap day
385+
if (target_idx == 999) {
386+
# there's an extra data point in front of the leap day
387+
entries <- (59 - window_size):(59 + window_size - 1) %% modulus
388+
entries[entries == 0] <- modulus
389+
# adding in the leap day itself
390+
entries <- c(entries, 999)
391+
} else if ((59 %in% entries) || (60 %in% entries)) {
392+
# if we're on the Feb/March boundary for daily data, we need to add in the
393+
# leap day data
394+
entries <- c(entries, 999)
395+
}
396+
} else if (modulus == 52) {
397+
# we need to grab just the window around the leap week on the leap week
398+
if (target_idx == 999) {
399+
entries <- (53 - window_size):(53 + window_size - 1) %% 52
400+
entries[entries == 0] <- 52
401+
entries <- c(entries, 999)
402+
} else if ((52 %in% entries) || (1 %in% entries)) {
403+
# if we're on the year boundary for weekly data, we need to add in the
404+
# leap week data (which is the extra week at the end)
405+
entries <- c(entries, 999)
406+
}
407+
}
408+
entries
392409
}
393410

394411

man/roll_modular_multivec.Rd

Lines changed: 9 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-step_climate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ test_that("prep/bake steps create the correct training data with an incomplete y
110110
r <- epi_recipe(x) %>% step_climate(y, time_type = "epiweek")
111111
p <- prep(r, x)
112112

113-
expected_res <- tibble(.idx = c(1:44, 999), climate_y = c(2, 3, 3, 4:25, 25, 25, 25:12, 12, 11, 11, 10))
113+
expected_res <- tibble(.idx = c(1:44, 999), climate_y = c(2, 3, 3, 4:25, 25, 25, 25:12, 12, 11, 11, 2))
114114
expect_equal(p$steps[[1]]$climate_table, expected_res)
115115

116116
b <- bake(p, new_data = NULL)

vignettes/epipredict.Rmd

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ We currently provide the following basic forecasters:
4545

4646
* _Flatline forecaster_: predicts as the median the most recently seen value
4747
with increasingly wide quantiles.
48+
* _Climatological forecaster_: predicts the median and quantiles based on the historical values around the same date in previous years.
4849
* _Autoregressive forecaster_: fits a model (e.g. linear regression) on
4950
lagged data to predict quantiles for continuous values.
5051
* _Autoregressive classifier_: fits a model (e.g. logistic regression) on
@@ -243,7 +244,6 @@ all_flatlines <- lapply(
243244
outcome = "death_rate",
244245
args_list = flatline_args_list(
245246
ahead = days_ahead,
246-
quantile_levels = c(0.05, 0.5, 0.95)
247247
)
248248
)
249249
}
@@ -262,6 +262,43 @@ autoplot(
262262
Note that the `cdc_baseline_forecaster` is a slight modification of this method
263263
for use in [the CDC COVID19 Forecasting Hub](https://covid19forecasthub.org/).
264264

265+
### `climatological_forecaster()`
266+
A different kind of baseline, `climatological_forecaster()` produces a
267+
point forecast and quantiles based on the historical values for this time of
268+
year, rather than extrapolating from recent values.
269+
For example, on the same dataset as above:
270+
```{r make-climatological-forecast, warning=FALSE}
271+
all_climate <- climatological_forecaster(
272+
covid_case_death_rates_extended |>
273+
filter(time_value <= forecast_date, geo_value %in% used_locations),
274+
outcome = "death_rate",
275+
args_list = climate_args_list(
276+
forecast_horizon = seq(0, 28),
277+
window_size = 14,
278+
time_type = "day",
279+
forecast_date = forecast_date
280+
)
281+
)
282+
workflow <- all_climate$epi_workflow
283+
results <- all_climate$predictions
284+
autoplot(
285+
object = workflow,
286+
predictions = results,
287+
plot_data = covid_case_death_rates_extended |> filter(geo_value %in% used_locations, time_value > "2021-07-01")
288+
)
289+
```
290+
291+
Note that we're using `covid_case_death_rates_extended` rather than
292+
`covid_case_death_rates`, since it starts in March of 2020 rather than December.
293+
Without at least a year's worth of historical data, it is impossible to do a
294+
climatological model.
295+
Even with only one year of data, as we have here, the resulting forecasts are unreliable.
296+
297+
One feature of the climatological baseline is that it forecasts multiple aheads
298+
simultaneously.
299+
This is also possible for `arx_forecaster()`, but only when using `trainer =
300+
smooth_quantile_reg()`, which is built to handle multiple aheads simultaneously.
301+
265302
### `arx_classifier()`
266303

267304
The most complicated of the canned forecasters, `arx_classifier` first

0 commit comments

Comments
 (0)