Skip to content

Commit 5171d91

Browse files
authored
enh: forecaster parameters cleanup (#192)
* make flu forecasters parameters more consistent * don't try linreg * update template.md * fix: borked seasonal_method "window", both in data table and in forecaster, just in case * doc: update README
1 parent 70d8e35 commit 5171d91

File tree

7 files changed

+222
-240
lines changed

7 files changed

+222
-240
lines changed

R/forecasters/forecaster_scaled_pop_seasonal.R

+71-40
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,27 @@
3535
#' @importFrom zeallot %<-%
3636
#' @importFrom recipes all_numeric
3737
#' @export
38-
scaled_pop_seasonal <- function(epi_data,
39-
outcome,
40-
extra_sources = "",
41-
ahead = 1,
42-
pop_scaling = TRUE,
43-
drop_non_seasons = FALSE,
44-
scale_method = c("quantile", "std", "none"),
45-
center_method = c("median", "mean", "none"),
46-
nonlin_method = c("quart_root", "none"),
47-
seasonal_method = c("none", "flu", "covid", "indicator", "window", "climatological"),
48-
seasonal_backward_window = 5 * 7,
49-
seasonal_forward_window = 3 * 7,
50-
train_residual = FALSE,
51-
trainer = epipredict::quantile_reg(),
52-
quantile_levels = covidhub_probs(),
53-
filter_source = "",
54-
filter_agg_level = "",
55-
clip_lower = TRUE,
56-
...) {
38+
scaled_pop_seasonal <- function(
39+
epi_data,
40+
outcome,
41+
extra_sources = "",
42+
ahead = 1,
43+
pop_scaling = TRUE,
44+
drop_non_seasons = FALSE,
45+
scale_method = c("quantile", "std", "none"),
46+
center_method = c("median", "mean", "none"),
47+
nonlin_method = c("quart_root", "none"),
48+
seasonal_method = c("none", "flu", "covid", "indicator", "window", "climatological"),
49+
seasonal_backward_window = 5 * 7,
50+
seasonal_forward_window = 3 * 7,
51+
train_residual = FALSE,
52+
trainer = epipredict::quantile_reg(),
53+
quantile_levels = covidhub_probs(),
54+
filter_source = "",
55+
filter_agg_level = "",
56+
clip_lower = TRUE,
57+
...
58+
) {
5759
scale_method <- arg_match(scale_method)
5860
center_method <- arg_match(center_method)
5961
nonlin_method <- arg_match(nonlin_method)
@@ -62,6 +64,9 @@ scaled_pop_seasonal <- function(epi_data,
6264
extra_sources <- unwrap_argument(extra_sources)
6365
trainer <- unwrap_argument(trainer)
6466

67+
if (typeof(seasonal_method) == "list") {
68+
seasonal_method <- seasonal_method[[1]]
69+
}
6570
if (all(seasonal_method == c("none", "flu", "covid", "indicator", "window", "climatological"))) {
6671
seasonal_method <- "none"
6772
}
@@ -100,7 +105,8 @@ scaled_pop_seasonal <- function(epi_data,
100105
args_list <- inject(default_args_list(!!!args_input))
101106
# if you want to hardcode particular predictors in a particular forecaster
102107
predictors <- c(outcome, extra_sources)
103-
c(args_list, predictors, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
108+
c(args_list, predictors, trainer) %<-%
109+
sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
104110

105111
if ("season_week" %nin% names(epi_data)) {
106112
epi_data %<>% add_season_info()
@@ -116,13 +122,27 @@ scaled_pop_seasonal <- function(epi_data,
116122
season_data <- epi_data
117123
}
118124
# TODO: Jank way to avoid having hhs_region get centered; this isn't very general
119-
learned_params <- calculate_whitening_params(season_data, setdiff(predictors, "hhs_region"), scale_method, center_method, nonlin_method)
125+
learned_params <- calculate_whitening_params(
126+
season_data,
127+
setdiff(predictors, "hhs_region"),
128+
scale_method,
129+
center_method,
130+
nonlin_method
131+
)
120132
epi_data %<>% data_whitening(setdiff(predictors, "hhs_region"), learned_params, nonlin_method)
121133

122134
# get the seasonal features
123135
# first add PCA
124136
if (("flu" %in% seasonal_method) || ("covid" %in% seasonal_method)) {
125-
epi_data <- compute_pca(epi_data, seasonal_method, ahead, scale_method, center_method, nonlin_method, normalize = train_residual)
137+
epi_data <- compute_pca(
138+
epi_data,
139+
seasonal_method,
140+
ahead,
141+
scale_method,
142+
center_method,
143+
nonlin_method,
144+
normalize = train_residual
145+
)
126146

127147
if (train_residual) {
128148
epi_data <- epi_data %>% mutate(across(all_of(outcome), ~ .x - PC1))
@@ -172,14 +192,15 @@ scaled_pop_seasonal <- function(epi_data,
172192
# preprocessing supported by epipredict
173193
preproc <- epi_recipe(epi_data)
174194
if (pop_scaling) {
175-
preproc %<>% step_population_scaling(
176-
all_of(predictors),
177-
df = epidatasets::state_census,
178-
df_pop_col = "pop",
179-
create_new = FALSE,
180-
rate_rescaling = 1e5,
181-
by = c("geo_value" = "abbr")
182-
)
195+
preproc %<>%
196+
step_population_scaling(
197+
all_of(predictors),
198+
df = epidatasets::state_census,
199+
df_pop_col = "pop",
200+
create_new = FALSE,
201+
rate_rescaling = 1e5,
202+
by = c("geo_value" = "abbr")
203+
)
183204
}
184205
if ("indicator" %in% seasonal_method) {
185206
preproc %<>%
@@ -201,14 +222,16 @@ scaled_pop_seasonal <- function(epi_data,
201222
postproc <- frosting()
202223
postproc %<>% arx_postprocess(trainer, args_list)
203224
if (pop_scaling) {
204-
postproc %<>% layer_population_scaling(
205-
.pred, .pred_distn,
206-
df = epidatasets::state_census,
207-
df_pop_col = "pop",
208-
create_new = FALSE,
209-
rate_rescaling = 1e5,
210-
by = c("geo_value" = "abbr")
211-
)
225+
postproc %<>%
226+
layer_population_scaling(
227+
.pred,
228+
.pred_distn,
229+
df = epidatasets::state_census,
230+
df_pop_col = "pop",
231+
create_new = FALSE,
232+
rate_rescaling = 1e5,
233+
by = c("geo_value" = "abbr")
234+
)
212235
}
213236
# with all the setup done, we execute and format
214237
pred <- run_workflow_and_format(preproc, postproc, trainer, season_data, epi_data)
@@ -217,7 +240,10 @@ scaled_pop_seasonal <- function(epi_data,
217240
# finally, any postprocessing not supported by epipredict e.g. calibration
218241
#
219242
# undo subtraction if we're training on residuals
220-
if (train_residual && (("flu" %in% seasonal_method) || ("covid" %in% seasonal_method) || ("climatological" %in% seasonal_method))) {
243+
if (
244+
train_residual &&
245+
(("flu" %in% seasonal_method) || ("covid" %in% seasonal_method) || ("climatological" %in% seasonal_method))
246+
) {
221247
pred <- pred %>%
222248
mutate(epi_week = epiweek(target_end_date)) %>%
223249
left_join(values_subtracted, by = join_by(geo_value, source, epi_week == epiweek)) %>%
@@ -228,7 +254,12 @@ scaled_pop_seasonal <- function(epi_data,
228254
# reintroduce color into the value
229255
pred_final <- pred %>%
230256
rename({{ outcome }} := value) %>%
231-
data_coloring(outcome, learned_params, join_cols = key_colnames(epi_data, exclude = "time_value"), nonlin_method = nonlin_method) %>%
257+
data_coloring(
258+
outcome,
259+
learned_params,
260+
join_cols = key_colnames(epi_data, exclude = "time_value"),
261+
nonlin_method = nonlin_method
262+
) %>%
232263
rename(value = {{ outcome }})
233264
if (clip_lower) {
234265
pred_final %<>% mutate(value = pmax(0, value))

R/new_epipredict_steps/step_training_window.R

+22-12
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#' Expects n_recent to be finite.
1616
#' @param seasonal_forward_window An integer value that represents the number of days
1717
#' after a season week to include in the training window. The default value
18-
#' is 14. Only valid when seasonal is TRUE.
18+
#' is 21. Only valid when seasonal is TRUE.
1919
#' @param seasonal_backward_window An integer value that represents the number of days
2020
#' before a season week to include in the training window. The default value
2121
#' is 35. Only valid when seasonal is TRUE.
@@ -50,14 +50,16 @@
5050
#' prep(tib) %>%
5151
#' bake(new_data = NULL)
5252
step_epi_training_window <-
53-
function(recipe,
54-
role = NA,
55-
n_recent = 50,
56-
seasonal = FALSE,
57-
seasonal_forward_window = 14,
58-
seasonal_backward_window = 35,
59-
epi_keys = NULL,
60-
id = rand_id("epi_training_window")) {
53+
function(
54+
recipe,
55+
role = NA,
56+
n_recent = 50,
57+
seasonal = FALSE,
58+
seasonal_forward_window = 21,
59+
seasonal_backward_window = 35,
60+
epi_keys = NULL,
61+
id = rand_id("epi_training_window")
62+
) {
6163
epipredict:::arg_is_scalar(n_recent, id, seasonal, seasonal_forward_window, seasonal_backward_window)
6264
epipredict:::arg_is_pos(n_recent, seasonal_forward_window, seasonal_backward_window)
6365
if (is.finite(n_recent)) epipredict:::arg_is_pos_int(n_recent)
@@ -150,7 +152,6 @@ bake.step_epi_training_window <- function(object, new_data, ...) {
150152
new_data %<>% filter(time_value %in% date_ranges)
151153
}
152154

153-
154155
new_data
155156
}
156157

@@ -162,8 +163,17 @@ print.step_epi_training_window <-
162163
n_recent <- x$n_recent
163164
seasonal_forward_window <- x$seasonal_forward_window
164165
seasonal_backward_window <- x$seasonal_backward_window
165-
tr_obj <- recipes::format_selectors(rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window), width)
166-
recipes::print_step(tr_obj, rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window), x$trained, title, width)
166+
tr_obj <- recipes::format_selectors(
167+
rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window),
168+
width
169+
)
170+
recipes::print_step(
171+
tr_obj,
172+
rlang::enquos(n_recent, seasonal_forward_window, seasonal_backward_window),
173+
x$trained,
174+
title,
175+
width
176+
)
167177
} else {
168178
title <- "# of recent observations per key limited to:"
169179
n_recent <- x$n_recent

R/targets/covid_forecaster_config.R

+1-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ get_covid_forecaster_params <- function() {
1414
out <- rlang::list2(
1515
scaled_pop_main = tidyr::expand_grid(
1616
forecaster = "scaled_pop",
17-
trainer = list("linreg", "quantreg"),
17+
trainer = "quantreg",
1818
lags = list(
1919
c(0, 7),
2020
c(0, 7, 14),
@@ -167,9 +167,6 @@ get_covid_forecaster_params <- function() {
167167
if ("trainer" %in% names(x) && is.list(x$trainer)) {
168168
x$trainer <- x$trainer[[1]]
169169
}
170-
if ("seasonal_method" %in% names(x) && is.list(x$seasonal_method)) {
171-
x$seasonal_method <- x$seasonal_method[[1]]
172-
}
173170
# Add the outcome to each forecaster.
174171
x$outcome <- "hhs"
175172
x

0 commit comments

Comments
 (0)