Skip to content

Commit 742c31a

Browse files
dsweber2 and dshemetov authored
feat: season summary (#197)
* season summary 2024-2025
* revision summary notebook
* exploration summary 2023-2024
* describe all forecasters tried
* future work
* lint: remove priority target args, as they're deprecated
* repo: renv
* feat: simplify daily_to_weekly_archive
* don't use epi_slide, group instead
* it's faster and simpler

---------

Co-authored-by: Dmitry Shemetov <[email protected]>
1 parent adce1f9 commit 742c31a

39 files changed

+1990
-494
lines changed

Makefile

+9-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ test:
1111
run:
1212
Rscript scripts/run.R
1313

14+
run-nohup:
15+
nohup Rscript scripts/run.R &
16+
17+
run-nohup-restarting:
18+
scripts/hardRestarting.sh &
19+
1420
prod-covid:
1521
export TAR_RUN_PROJECT=covid_hosp_prod; Rscript scripts/run.R
1622

@@ -65,12 +71,6 @@ get-nwss:
6571
python nwss_covid_export.py; \
6672
python nwss_influenza_export.py
6773

68-
run-nohup:
69-
nohup Rscript scripts/run.R &
70-
71-
run-nohup-restarting:
72-
scripts/hardRestarting.sh &
73-
7474
sync:
7575
Rscript -e "source('R/sync_aws.R'); sync_aws()"
7676

@@ -98,3 +98,6 @@ get-flu-prod-errors:
9898

9999
get-covid-prod-errors:
100100
Rscript -e "suppressPackageStartupMessages(source(here::here('R', 'load_all.R'))); get_targets_errors(project = 'covid_hosp_prod')"
101+
102+
summary_reports:
103+
Rscript scripts/summary_reports.R

R/aux_data_utils.R

+15-56
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,15 @@ daily_to_weekly <- function(epi_df, agg_method = c("sum", "mean"), keys = "geo_v
213213
#' @param epi_arch the archive to aggregate.
214214
#' @param agg_columns the columns to aggregate.
215215
#' @param agg_method the method to use to aggregate the data, one of "sum" or "mean".
216-
#' @param day_of_week the day of the week to use as the reference day.
217-
#' @param day_of_week_end the day of the week to use as the end of the week.
216+
#' @param week_reference the day of the week to use as the reference day (Wednesday is default).
217+
#' Note that this is 1-indexed, so 1 = Sunday, 2 = Monday, ..., 7 = Saturday.
218+
#' @param week_start the day of the week to use as the start of the week (Sunday is default).
219+
#' Note that this is passed directly to lubridate, so it follows lubridate's `week_start` convention: 1 = Monday, 2 = Tuesday, ..., 7 = Sunday.
218220
daily_to_weekly_archive <- function(epi_arch,
219221
agg_columns,
220222
agg_method = c("sum", "mean"),
221-
day_of_week = 4L,
222-
day_of_week_end = 7L) {
223+
week_reference = 4L,
224+
week_start = 7L) {
223225
# How to aggregate the windowed data.
224226
agg_method <- arg_match(agg_method)
225227
# The columns we will later group by when aggregating.
@@ -230,67 +232,24 @@ daily_to_weekly_archive <- function(epi_arch,
230232
sort()
231233
# Choose a fast function to use to slide and aggregate.
232234
if (agg_method == "sum") {
233-
slide_fun <- epi_slide_sum
235+
# If the week is complete, this is equivalent to the sum. If the week is not
236+
# complete, this is equivalent to 7/(number of days in the week) * the sum,
237+
# which should be a decent approximation.
238+
agg_fun <- \(x) 7 * mean(x, na.rm = TRUE)
234239
} else if (agg_method == "mean") {
235-
slide_fun <- epi_slide_mean
240+
agg_fun <- \(x) mean(x, na.rm = TRUE)
236241
}
237242
# Slide over the versions and aggregate.
238243
epix_slide(
239244
epi_arch,
240245
.versions = ref_time_values,
241246
function(x, group_keys, ref_time) {
242-
# The last day of the week we will slide over.
243-
ref_time_last_week_end <- floor_date(ref_time, "week", day_of_week_end - 1)
244-
245-
# To find the days we will slide over, we need to find the first and last
246-
# complete weeks of data. Get the max and min times, and then find the
247-
# first and last complete weeks of data.
248-
min_time <- min(x$time_value)
249-
max_time <- max(x$time_value)
250-
251-
# Let's determine if the min and max times are in the same week.
252-
ceil_min_time <- ceiling_date(min_time, "week", week_start = day_of_week_end - 1)
253-
ceil_max_time <- ceiling_date(max_time, "week", week_start = day_of_week_end - 1)
254-
255-
# If they're not in the same week, this means we have at least one
256-
# complete week of data to slide over.
257-
if (ceil_min_time < ceil_max_time) {
258-
valid_slide_days <- seq.Date(
259-
from = ceiling_date(min_time, "week", week_start = day_of_week_end - 1),
260-
to = floor_date(max_time, "week", week_start = day_of_week_end - 1),
261-
by = 7L
262-
)
263-
} else {
264-
# This is the degenerate case, where we have about 1 week or less of
265-
# data. In this case, we opt to return nothing for two reasons:
266-
# 1. in most cases here, the data is incomplete for a single week,
267-
# 2. if the data is complete, a single week of data is not enough to
268-
# reasonably perform any kind of aggregation.
269-
return(tibble())
270-
}
271-
272-
# If the last day of the week is not the end of the week, add it to the
273-
# list of valid slide days (this will produce an incomplete slide, but
274-
# that's fine for us, since it should only be 1 day, historically.)
275-
if (wday(max_time) != day_of_week_end) {
276-
valid_slide_days <- c(valid_slide_days, max_time)
277-
}
278-
279247
# Slide over the days and aggregate.
280248
x %>%
281-
group_by(across(all_of(keys))) %>%
282-
slide_fun(
283-
agg_columns,
284-
.window_size = 7L,
285-
na.rm = TRUE,
286-
.ref_time_values = valid_slide_days
287-
) %>%
288-
select(-all_of(agg_columns)) %>%
289-
rename_with(~ gsub("slide_value_", "", .x)) %>%
290-
rename_with(~ gsub("_7dsum", "", .x)) %>%
291-
# Round all dates to reference day of the week. These will get
292-
# de-duplicated by compactify in as_epi_archive below.
293-
mutate(time_value = round_date(time_value, "week", day_of_week - 1)) %>%
249+
mutate(week_start = ceiling_date(time_value, "week", week_start = week_start)-1) %>%
250+
summarize(across(all_of(agg_columns), agg_fun), .by = all_of(c(keys, "week_start"))) %>%
251+
mutate(time_value = round_date(week_start, "week", week_reference - 1)) %>%
252+
select(-week_start) %>%
294253
as_tibble()
295254
}
296255
) %>%

R/forecasters/data_validation.R

-17
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ confirm_sufficient_data <- function(epi_data, ahead, args_input, outcome, extra_
6868
# TODO: Buffer should probably be 2 * n(lags) * n(predictors). But honestly,
6969
# this needs to be fixed in epipredict itself, see
7070
# https://github.com/cmu-delphi/epipredict/issues/106.
71-
if (identical(extra_sources, "")) {
72-
extra_sources <- character(0L)
73-
}
7471
has_no_last_nas <- epi_data %>%
7572
drop_na(c(!!outcome, !!!extra_sources)) %>%
7673
group_by(geo_value) %>%
@@ -106,17 +103,3 @@ filter_minus_one_ahead <- function(epi_data, ahead) {
106103
}
107104
epi_data
108105
}
109-
110-
#' Unwrap an argument if it's a list of length 1
111-
#'
112-
#' Many of our arguments to the forecasters come as lists not because we expect
113-
#' them that way, but as a byproduct of tibble and expand_grid.
114-
unwrap_argument <- function(arg, default_trigger = "", default = character(0L)) {
115-
if (is.list(arg) && length(arg) == 1) {
116-
arg <- arg[[1]]
117-
}
118-
if (identical(arg, default_trigger)) {
119-
return(default)
120-
}
121-
return(arg)
122-
}

R/forecasters/ensemble_average.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ensemble_average <- function(epi_data,
2626
forecasts,
2727
outcome,
28-
extra_sources = "",
28+
extra_sources = character(),
2929
ensemble_args = list(),
3030
ensemble_args_names = NULL) {
3131
# unique parameters must be buried in ensemble_args so that the generic function signature is stable

R/forecasters/forecaster_climatological.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#'
33
climate_linear_ensembled <- function(epi_data,
44
outcome,
5-
extra_sources = "",
5+
extra_sources = character(),
66
ahead = 7,
77
trainer = parsnip::linear_reg(),
88
quantile_levels = covidhub_probs(),
@@ -22,8 +22,7 @@ climate_linear_ensembled <- function(epi_data,
2222
nonlin_method <- arg_match(nonlin_method)
2323

2424
epi_data <- validate_epi_data(epi_data)
25-
extra_sources <- unwrap_argument(extra_sources)
26-
trainer <- unwrap_argument(trainer)
25+
extra_sources <- unlist(extra_sources)
2726

2827
args_list <- list(...)
2928
ahead <- as.integer(ahead / 7)

R/forecasters/forecaster_flatline.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,15 @@
1010
#' @export
1111
flatline_fc <- function(epi_data,
1212
outcome,
13-
extra_sources = "",
13+
extra_sources = character(),
1414
ahead = 1,
1515
trainer = parsnip::linear_reg(),
1616
quantile_levels = covidhub_probs(),
1717
filter_source = "",
1818
filter_agg_level = "",
1919
...) {
2020
epi_data <- validate_epi_data(epi_data)
21-
extra_sources <- unwrap_argument(extra_sources)
22-
trainer <- unwrap_argument(trainer)
21+
extra_sources <- unlist(extra_sources)
2322

2423
# perform any preprocessing not supported by epipredict
2524
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)

R/forecasters/forecaster_flusion.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
flusion <- function(epi_data,
22
outcome,
3-
extra_sources = "",
3+
extra_sources = character(),
44
ahead = 7,
55
pop_scaling = FALSE,
66
trainer = rand_forest(
@@ -24,8 +24,7 @@ flusion <- function(epi_data,
2424
derivative_estimator <- arg_match(derivative_estimator)
2525

2626
epi_data <- validate_epi_data(epi_data)
27-
extra_sources <- unwrap_argument(extra_sources)
28-
trainer <- unwrap_argument(trainer)
27+
extra_sources <- unlist(extra_sources)
2928

3029
# perform any preprocessing not supported by epipredict
3130
args_input <- list(...)

R/forecasters/forecaster_no_recent_outcome.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#' it may whiten any old data as the outcome
33
no_recent_outcome <- function(epi_data,
44
outcome,
5-
extra_sources = "",
5+
extra_sources = character(),
66
ahead = 7,
77
pop_scaling = FALSE,
88
trainer = epipredict::quantile_reg(),
@@ -24,8 +24,7 @@ no_recent_outcome <- function(epi_data,
2424
week_method <- arg_match(week_method)
2525

2626
epi_data <- validate_epi_data(epi_data)
27-
extra_sources <- unwrap_argument(extra_sources)
28-
trainer <- unwrap_argument(trainer)
27+
extra_sources <- unlist(extra_sources)
2928

3029
# this is for the case where there are multiple sources in the same column
3130
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)

R/forecasters/forecaster_scaled_pop.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
#' @export
4848
scaled_pop <- function(epi_data,
4949
outcome,
50-
extra_sources = "",
50+
extra_sources = character(),
5151
ahead = 1,
5252
pop_scaling = TRUE,
5353
drop_non_seasons = FALSE,
@@ -64,8 +64,7 @@ scaled_pop <- function(epi_data,
6464
nonlin_method <- arg_match(nonlin_method)
6565

6666
epi_data <- validate_epi_data(epi_data)
67-
extra_sources <- unwrap_argument(extra_sources)
68-
trainer <- unwrap_argument(trainer)
67+
extra_sources <- unlist(extra_sources)
6968

7069
# perform any preprocessing not supported by epipredict
7170
#

R/forecasters/forecaster_scaled_pop_seasonal.R

+3-7
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
scaled_pop_seasonal <- function(
3939
epi_data,
4040
outcome,
41-
extra_sources = "",
41+
extra_sources = character(),
4242
ahead = 1,
4343
pop_scaling = TRUE,
4444
drop_non_seasons = FALSE,
@@ -61,13 +61,9 @@ scaled_pop_seasonal <- function(
6161
nonlin_method <- arg_match(nonlin_method)
6262

6363
epi_data <- validate_epi_data(epi_data)
64-
extra_sources <- unwrap_argument(extra_sources)
65-
trainer <- unwrap_argument(trainer)
64+
extra_sources <- unlist(extra_sources)
6665

67-
if (typeof(seasonal_method) == "list") {
68-
seasonal_method <- seasonal_method[[1]]
69-
}
70-
if (all(seasonal_method == c("none", "flu", "covid", "indicator", "window", "climatological"))) {
66+
if (identical(seasonal_method, c("none", "flu", "covid", "indicator", "window", "climatological"))) {
7167
seasonal_method <- "none"
7268
}
7369
# perform any preprocessing not supported by epipredict

R/forecasters/forecaster_smoothed_scaled.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
#' @export
5252
smoothed_scaled <- function(epi_data,
5353
outcome,
54-
extra_sources = "",
54+
extra_sources = character(),
5555
ahead = 1,
5656
pop_scaling = TRUE,
5757
trainer = parsnip::linear_reg(),
@@ -73,8 +73,7 @@ smoothed_scaled <- function(epi_data,
7373
nonlin_method <- arg_match(nonlin_method)
7474

7575
epi_data <- validate_epi_data(epi_data)
76-
extra_sources <- unwrap_argument(extra_sources)
77-
trainer <- unwrap_argument(trainer)
76+
extra_sources <- unlist(extra_sources)
7877

7978
# perform any preprocessing not supported by epipredict
8079
#

R/forecasters/formatters.R

+14-11
Original file line numberDiff line numberDiff line change
@@ -72,24 +72,27 @@ format_flusight <- function(pred, disease = c("flu", "covid")) {
7272
}
7373

7474
format_scoring_utils <- function(forecasts_and_ensembles, disease = c("flu", "covid")) {
75-
forecasts_and_ensembles %>%
76-
filter(!grepl("region.*", geo_value)) %>%
77-
mutate(
78-
reference_date = get_forecast_reference_date(forecast_date),
79-
target = glue::glue("wk inc {disease} hosp"),
80-
horizon = as.integer(floor((target_end_date - reference_date) / 7)),
81-
output_type = "quantile",
82-
output_type_id = quantile,
83-
value = value
84-
) %>%
75+
# dplyr here was unreasonably slow on 1m+ rows, so replacing with direct access
76+
fc_ens <- forecasts_and_ensembles
77+
fc_ens <- fc_ens[!grepl("region.*", forecasts_and_ensembles$geo_value), ]
78+
fc_ens[, "reference_date"] <- get_forecast_reference_date(fc_ens$forecast_date)
79+
fc_ens[, "target"] <- glue::glue("wk inc {disease} hosp")
80+
fc_ens[, "horizon"] <- as.integer(floor((fc_ens$target_end_date - fc_ens$reference_date) / 7))
81+
fc_ens[, "output_type"] <- "quantile"
82+
fc_ens[, "output_type_id"] <- fc_ens$quantile
83+
fc_ens %>%
8584
left_join(
8685
get_population_data() %>%
8786
select(state_id, state_code),
8887
by = c("geo_value" = "state_id")
8988
) %>%
9089
rename(location = state_code, model_id = forecaster) %>%
9190
select(reference_date, target, horizon, target_end_date, location, output_type, output_type_id, value, model_id) %>%
92-
drop_na()
91+
drop_na() %>%
92+
arrange(location, target_end_date, reference_date, output_type_id) %>%
93+
group_by(model_id, location, target_end_date, reference_date) %>%
94+
mutate(value = sort(value)) %>%
95+
ungroup()
9396
}
9497

9598
#' The quantile levels used by the covidhub repository

R/imports.R

+3
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ library(crew)
99
library(data.table)
1010
library(dplyr)
1111
library(DT)
12+
options(DT.options = list(scrollX = TRUE))
1213
library(epidatr)
1314
library(epipredict)
1415
library(epiprocess)
1516
library(ggplot2)
1617
library(glue)
1718
library(grf)
1819
library(here)
20+
library(httpgd)
1921
if (Sys.getenv("COVID_SUBMISSION_DIRECTORY", "cache") != "cache") {
2022
library(hubValidations)
2123
}
@@ -36,6 +38,7 @@ library(recipes)
3638
library(renv)
3739
library(rlang)
3840
library(rspm)
41+
library(scales)
3942
library(scoringutils)
4043
library(slider)
4144
library(stringr)

0 commit comments

Comments
 (0)