Skip to content

Commit 3386797

Browse files
committed
fix: a very complicated way to fix a bug
1 parent e065f8d commit 3386797

15 files changed

+40
-84
lines changed

R/forecasters/data_validation.R

-14
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,3 @@ filter_minus_one_ahead <- function(epi_data, ahead) {
106106
}
107107
epi_data
108108
}
109-
110-
#' Unwrap an argument if it's a list of length 1
111-
#'
112-
#' Many of our arguments to the forecasters come as lists not because we expect
113-
#' them that way, but as a byproduct of tibble and expand_grid.
114-
unwrap_argument <- function(arg, default_trigger = "", default = character(0L)) {
115-
if (is.list(arg) && length(arg) == 1) {
116-
arg <- arg[[1]]
117-
}
118-
if (identical(arg, default_trigger)) {
119-
return(default)
120-
}
121-
return(arg)
122-
}

R/forecasters/ensemble_average.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ensemble_average <- function(epi_data,
2626
forecasts,
2727
outcome,
28-
extra_sources = "",
28+
extra_sources = character(),
2929
ensemble_args = list(),
3030
ensemble_args_names = NULL) {
3131
# unique parameters must be buried in ensemble_args so that the generic function signature is stable

R/forecasters/forecaster_climatological.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#'
33
climate_linear_ensembled <- function(epi_data,
44
outcome,
5-
extra_sources = "",
5+
extra_sources = character(),
66
ahead = 7,
77
trainer = parsnip::linear_reg(),
88
quantile_levels = covidhub_probs(),
@@ -22,8 +22,7 @@ climate_linear_ensembled <- function(epi_data,
2222
nonlin_method <- arg_match(nonlin_method)
2323

2424
epi_data <- validate_epi_data(epi_data)
25-
extra_sources <- unwrap_argument(extra_sources)
26-
trainer <- unwrap_argument(trainer)
25+
extra_sources <- unlist(extra_sources)
2726

2827
args_list <- list(...)
2928
ahead <- as.integer(ahead / 7)

R/forecasters/forecaster_flatline.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,15 @@
1010
#' @export
1111
flatline_fc <- function(epi_data,
1212
outcome,
13-
extra_sources = "",
13+
extra_sources = character(),
1414
ahead = 1,
1515
trainer = parsnip::linear_reg(),
1616
quantile_levels = covidhub_probs(),
1717
filter_source = "",
1818
filter_agg_level = "",
1919
...) {
2020
epi_data <- validate_epi_data(epi_data)
21-
extra_sources <- unwrap_argument(extra_sources)
22-
trainer <- unwrap_argument(trainer)
21+
extra_sources <- unlist(extra_sources)
2322

2423
# perform any preprocessing not supported by epipredict
2524
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)

R/forecasters/forecaster_flusion.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
flusion <- function(epi_data,
22
outcome,
3-
extra_sources = "",
3+
extra_sources = character(),
44
ahead = 7,
55
pop_scaling = FALSE,
66
trainer = rand_forest(
@@ -24,8 +24,7 @@ flusion <- function(epi_data,
2424
derivative_estimator <- arg_match(derivative_estimator)
2525

2626
epi_data <- validate_epi_data(epi_data)
27-
extra_sources <- unwrap_argument(extra_sources)
28-
trainer <- unwrap_argument(trainer)
27+
extra_sources <- unlist(extra_sources)
2928

3029
# perform any preprocessing not supported by epipredict
3130
args_input <- list(...)

R/forecasters/forecaster_no_recent_outcome.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#' it may whiten any old data as the outcome
33
no_recent_outcome <- function(epi_data,
44
outcome,
5-
extra_sources = "",
5+
extra_sources = character(),
66
ahead = 7,
77
pop_scaling = FALSE,
88
trainer = epipredict::quantile_reg(),
@@ -24,8 +24,7 @@ no_recent_outcome <- function(epi_data,
2424
week_method <- arg_match(week_method)
2525

2626
epi_data <- validate_epi_data(epi_data)
27-
extra_sources <- unwrap_argument(extra_sources)
28-
trainer <- unwrap_argument(trainer)
27+
extra_sources <- unlist(extra_sources)
2928

3029
# this is for the case where there are multiple sources in the same column
3130
epi_data %<>% filter_extraneous(filter_source, filter_agg_level)

R/forecasters/forecaster_scaled_pop.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
#' @export
4848
scaled_pop <- function(epi_data,
4949
outcome,
50-
extra_sources = "",
50+
extra_sources = character(),
5151
ahead = 1,
5252
pop_scaling = TRUE,
5353
drop_non_seasons = FALSE,
@@ -64,8 +64,7 @@ scaled_pop <- function(epi_data,
6464
nonlin_method <- arg_match(nonlin_method)
6565

6666
epi_data <- validate_epi_data(epi_data)
67-
extra_sources <- unwrap_argument(extra_sources)
68-
trainer <- unwrap_argument(trainer)
67+
extra_sources <- unlist(extra_sources)
6968

7069
# perform any preprocessing not supported by epipredict
7170
#

R/forecasters/forecaster_scaled_pop_seasonal.R

+2-6
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
scaled_pop_seasonal <- function(
3939
epi_data,
4040
outcome,
41-
extra_sources = "",
41+
extra_sources = character(),
4242
ahead = 1,
4343
pop_scaling = TRUE,
4444
drop_non_seasons = FALSE,
@@ -61,12 +61,8 @@ scaled_pop_seasonal <- function(
6161
nonlin_method <- arg_match(nonlin_method)
6262

6363
epi_data <- validate_epi_data(epi_data)
64-
extra_sources <- unwrap_argument(extra_sources)
65-
trainer <- unwrap_argument(trainer)
64+
extra_sources <- unlist(extra_sources)
6665

67-
if (typeof(seasonal_method) == "list") {
68-
seasonal_method <- seasonal_method[[1]]
69-
}
7066
if (all(seasonal_method == c("none", "flu", "covid", "indicator", "window", "climatological"))) {
7167
seasonal_method <- "none"
7268
}

R/forecasters/forecaster_smoothed_scaled.R

+2-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
#' @export
5252
smoothed_scaled <- function(epi_data,
5353
outcome,
54-
extra_sources = "",
54+
extra_sources = character(),
5555
ahead = 1,
5656
pop_scaling = TRUE,
5757
trainer = parsnip::linear_reg(),
@@ -73,8 +73,7 @@ smoothed_scaled <- function(epi_data,
7373
nonlin_method <- arg_match(nonlin_method)
7474

7575
epi_data <- validate_epi_data(epi_data)
76-
extra_sources <- unwrap_argument(extra_sources)
77-
trainer <- unwrap_argument(trainer)
76+
extra_sources <- unlist(extra_sources)
7877

7978
# perform any preprocessing not supported by epipredict
8079
#

R/targets/covid_forecaster_config.R

+11-11
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
#'
88
#' Variables with 'g_' prefix are globals defined in the calling script.
99
#'
10+
#' Note that expand_grid has some quirks:
11+
#' - if an entry is a vector c() or a list(), each top-level element is expanded out to a row.
12+
#' - this means that list(list()) reuses the same inner list for each row.
13+
#'
1014
#' @param dummy_mode Boolean indicating whether to use dummy forecasters
1115
#' @return A list of forecaster parameter combinations
1216
#' @export
@@ -32,7 +36,6 @@ get_covid_forecaster_params <- function() {
3236
expand_grid(
3337
forecaster = "scaled_pop",
3438
trainer = "quantreg",
35-
# since it's a list, this gets expanded out to a single one in each row
3639
extra_sources = list2("nssp", "google_symptoms", "nwss", "nwss_region", "va_covid_per_100k"),
3740
lags = list2(
3841
list2(
@@ -105,18 +108,18 @@ get_covid_forecaster_params <- function() {
105108
scaled_pop_season = tidyr::expand_grid(
106109
forecaster = "scaled_pop_seasonal",
107110
trainer = "quantreg",
108-
lags = list(
111+
lags = list2(
109112
c(0, 7, 14, 21),
110113
c(0, 7)
111114
),
112115
pop_scaling = FALSE,
113116
n_training = Inf,
114-
seasonal_method = list(
115-
c("covid"),
116-
c("window"),
117-
c("covid", "window"),
118-
c("climatological"),
119-
c("climatological", "window")
117+
seasonal_method = list2(
118+
list2("covid"),
119+
list2("window"),
120+
list2("covid", "window"),
121+
list2("climatological"),
122+
list2("climatological", "window")
120123
)
121124
),
122125
climate_linear = bind_rows(
@@ -165,9 +168,6 @@ get_covid_forecaster_params <- function() {
165168
x$forecaster <- "dummy_forecaster"
166169
}
167170
x <- add_id(x)
168-
if ("trainer" %in% names(x) && is.list(x$trainer)) {
169-
x$trainer <- x$trainer[[1]]
170-
}
171171
# Add the outcome to each forecaster.
172172
x$outcome <- "hhs"
173173
x

R/targets/flu_forecaster_config.R

+8-28
Original file line numberDiff line numberDiff line change
@@ -209,31 +209,16 @@ get_flu_forecaster_params <- function() {
209209
tidyr::expand_grid(
210210
forecaster = "scaled_pop_seasonal",
211211
trainer = "quantreg",
212-
lags = list2(
213-
c(0, 7)
212+
lags = list2(c(0, 7)),
213+
seasonal_method = list2(
214+
list2("window"),
215+
list2("window", "flu"),
216+
list2("window", "climatological")
214217
),
215-
seasonal_method = list("flu", "indicator", "climatological"),
216-
pop_scaling = FALSE,
217-
train_residual = c(TRUE, FALSE),
218-
filter_source = c("", "nhsn"),
219-
filter_agg_level = "state",
220-
drop_non_seasons = c(TRUE, FALSE),
221-
n_training = Inf,
222-
keys_to_ignore = g_very_latent_locations
223-
),
224-
# Window-based seasonal method shouldn't drop non-seasons
225-
tidyr::expand_grid(
226-
forecaster = "scaled_pop_seasonal",
227-
trainer = "quantreg",
228-
lags = list(
229-
c(0, 7)
230-
),
231-
seasonal_method = list("window", c("window", "flu"), c("window", "climatological")),
232218
pop_scaling = FALSE,
233219
train_residual = c(FALSE, TRUE),
234220
filter_source = c("", "nhsn"),
235221
filter_agg_level = "state",
236-
drop_non_seasons = FALSE,
237222
n_training = Inf,
238223
keys_to_ignore = g_very_latent_locations
239224
)
@@ -250,7 +235,7 @@ get_flu_forecaster_params <- function() {
250235
c(0, 7) # exogenous feature
251236
)
252237
),
253-
seasonal_method = "window",
238+
seasonal_method = list2("window"),
254239
pop_scaling = FALSE,
255240
filter_source = c("", "nhsn"),
256241
filter_agg_level = "state",
@@ -262,10 +247,8 @@ get_flu_forecaster_params <- function() {
262247
season_window_sizes = tidyr::expand_grid(
263248
forecaster = "scaled_pop_seasonal",
264249
trainer = "quantreg",
265-
lags = list(
266-
c(0, 7)
267-
),
268-
seasonal_method = "window",
250+
lags = list2(c(0, 7)),
251+
seasonal_method = list2("window"),
269252
pop_scaling = FALSE,
270253
train_residual = FALSE,
271254
filter_source = c("", "nhsn"),
@@ -325,9 +308,6 @@ get_flu_forecaster_params <- function() {
325308
x$forecaster <- "dummy_forecaster"
326309
}
327310
x <- add_id(x)
328-
if ("trainer" %in% names(x) && is.list(x$trainer)) {
329-
x$trainer <- x$trainer[[1]]
330-
}
331311
# Add the outcome to each forecaster.
332312
x$outcome <- "hhs"
333313
x

R/utils.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ forecaster_lookup <- function(pattern, forecaster_params_grid = NULL) {
3030

3131
out <- forecaster_params_grid %>% filter(grepl(pattern, .data$id))
3232
if (nrow(out) > 0) {
33-
out %>% glimpse()
33+
out %>% unlist()
3434
return(out)
3535
}
3636
}

scripts/covid_hosp_explore.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ g_reports_dir <- "reports"
2929
g_fetch_args <- epidatr::fetch_args_list(return_empty = FALSE, timeout_seconds = 400)
3030
# Geos with insufficient data for forecasting.
3131
g_insufficient_data_geos <- c("as", "pr", "vi", "gu", "mp")
32-
# Human-readable object to be used for inspecting the forecasters in the pipeline.
32+
# Parameters object used for grouping forecasters by family.
3333
g_forecaster_parameter_combinations <- get_covid_forecaster_params()
34-
# Targets-readable object to be used for running the pipeline.
34+
# Targets-readable object used for running the pipeline.
3535
g_forecaster_params_grid <- g_forecaster_parameter_combinations %>%
3636
imap(\(x, i) make_forecaster_grid(x, i)) %>%
3737
bind_rows()

scripts/flu_hosp_explore.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ g_very_latent_locations <- list(list(
3434
c("source"),
3535
c("flusurv", "ILI+")
3636
))
37-
# Human-readable object to be used for inspecting the forecasters in the pipeline.
37+
# Parameters object used for grouping forecasters by family.
3838
g_forecaster_parameter_combinations <- get_flu_forecaster_params()
39-
# Targets-readable object to be used for running the pipeline.
39+
# Targets-readable object used for running the pipeline.
4040
g_forecaster_params_grid <- g_forecaster_parameter_combinations %>%
4141
imap(\(x, i) make_forecaster_grid(x, i)) %>%
4242
bind_rows()

scripts/one_offs/forecaster_profiling.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ p <- profvis::profvis({
1111
epi_archive = d,
1212
outcome = "hhs",
1313
ahead = 2,
14-
extra_sources = "",
14+
extra_sources = character(),
1515
forecaster = scaled_pop,
1616
n_training_pad = 30L,
1717
forecaster_args = list(

0 commit comments

Comments
 (0)