Skip to content

Commit 993b571

Browse files
authored
Merge pull request #76 from cmu-delphi/smoothedScaled
Add the 7dav we talked about along with the std
2 parents a4f5fd2 + 15d35d7 commit 993b571

25 files changed

+623
-40
lines changed

DESCRIPTION

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ Imports:
2727
purrr,
2828
recipes (>= 1.0.4),
2929
rlang,
30+
slider,
3031
targets,
3132
tibble,
32-
tidyr
33+
tidyr,
34+
zeallot
3335
Suggests:
3436
ggplot2,
3537
knitr,

NAMESPACE

+10-1
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@ export(make_target_ensemble_grid)
2929
export(make_target_param_grid)
3030
export(manage_S3_forecast_cache)
3131
export(overprediction)
32-
export(perform_sanity_checks)
3332
export(read_external_predictions_data)
33+
export(rolling_mean)
34+
export(rolling_sd)
3435
export(run_evaluation_measure)
3536
export(run_workflow_and_format)
37+
export(sanitize_args_predictors_trainer)
3638
export(scaled_pop)
3739
export(sharpness)
3840
export(single_id)
3941
export(slide_forecaster)
42+
export(smoothed_scaled)
4043
export(underprediction)
44+
export(update_predictors)
4145
export(weighted_interval_score)
4246
importFrom(assertthat,assert_that)
4347
importFrom(aws.s3,get_bucket)
@@ -84,14 +88,17 @@ importFrom(epipredict,step_epi_naomit)
8488
importFrom(epipredict,step_population_scaling)
8589
importFrom(epipredict,step_training_window)
8690
importFrom(epiprocess,as_epi_df)
91+
importFrom(epiprocess,epi_slide)
8792
importFrom(epiprocess,epix_slide)
8893
importFrom(here,here)
8994
importFrom(magrittr,"%<>%")
9095
importFrom(magrittr,"%>%")
9196
importFrom(purrr,imap)
9297
importFrom(purrr,map)
9398
importFrom(purrr,map2_vec)
99+
importFrom(purrr,map_chr)
94100
importFrom(purrr,map_vec)
101+
importFrom(purrr,reduce)
95102
importFrom(purrr,transpose)
96103
importFrom(recipes,all_numeric)
97104
importFrom(rlang,"!!")
@@ -100,6 +107,7 @@ importFrom(rlang,.data)
100107
importFrom(rlang,quo)
101108
importFrom(rlang,sym)
102109
importFrom(rlang,syms)
110+
importFrom(slider,slide_dbl)
103111
importFrom(targets,tar_config_get)
104112
importFrom(targets,tar_group)
105113
importFrom(targets,tar_read)
@@ -109,3 +117,4 @@ importFrom(tidyr,drop_na)
109117
importFrom(tidyr,expand_grid)
110118
importFrom(tidyr,pivot_wider)
111119
importFrom(tidyr,unnest)
120+
importFrom(zeallot,"%<-%")

R/data_transforms.R

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# various reusable transforms to apply before handing to epipredict
2+
3+
#' pick the columns that the transforms should operate on
#' @description
#' If `cols` is given it is returned untouched. Otherwise all non-key columns
#' of `epi_data` are used, except those whose name contains the pattern
#' produced by the rolling transforms (an underscore, one or two word
#' characters, then digits, e.g. `"a_m7"` or `"a_sd28"`).
#' @keywords internal
#' @param epi_data the `epi_df`
#' @param cols vector of column names to use. If `NULL`, fill with all non-key columns
#' @return the column names to transform
get_trainable_names <- function(epi_data, cols) {
  # an explicit selection always wins
  if (!is.null(cols)) {
    return(cols)
  }
  candidates <- get_nonkey_names(epi_data)
  # skip anything named like the rolling average/sd columns created in this file
  looks_derived <- grepl("_\\w{1,2}\\d+", candidates)
  return(candidates[!looks_derived])
}
15+
16+
#' list the columns of an epi_df that are not keys
#' @description
#' the column names of `epi_data`, minus `geo_value`, `time_value`, and
#' anything listed in the `other_keys` entry of its `metadata` attribute.
#' Order of the remaining names is preserved.
#' @param epi_data the epi_df
#' @return a character vector of non-key column names
get_nonkey_names <- function(epi_data) {
  # other_keys is simply absent (NULL) when there is no metadata attribute
  key_names <- c("geo_value", "time_value", attr(epi_data, "metadata")$other_keys)
  all_names <- names(epi_data)
  is_key <- all_names %in% key_names
  return(all_names[!is_key])
}
25+
26+
27+
#' update the predictors to only contain the smoothed/sd versions of cols
#' @description
#' modifies the list of predictors so that any which have been modified have the
#' modified versions included, and not the original. Should only be applied
#' after both rolling_mean and rolling_sd.
#'
#' Matching is done by literal substring: a predictor counts as modified when
#' its name contains one of `cols_modified`, and a column of `epi_data` counts
#' as a modified version when its name contains a predictor's name and is not
#' itself one of `predictors`.
#' @param epi_data the epi_df, only included to get the non-key column names
#' @param cols_modified the list of columns which have been modified. If this is `NULL`, that means we were modifying every column.
#' @param predictors the initial set of predictors; any unmodified are kept, any modified are replaced with the modified versions (e.g. "a" becoming "a_m17").
#' @importFrom purrr map map_chr reduce
#' @return returns an updated list of predictors, with modified columns replaced and non-modified columns left intact.
#' @export
update_predictors <- function(epi_data, cols_modified, predictors) {
  if (is.null(cols_modified)) {
    # `NULL` means every predictor was modified, so nothing is carried over as-is
    unchanged_predictors <- character(0L)
  } else {
    # keep the predictors that contain none of the modified column names;
    # `.init` makes an empty `cols_modified` mean "nothing was modified"
    # instead of failing inside `reduce()`
    is_unchanged <- map(cols_modified, ~ !grepl(.x, predictors, fixed = TRUE)) %>%
      reduce(`&`, .init = rep(TRUE, length(predictors)))
    unchanged_predictors <- predictors[is_unchanged]
  }
  # all the non-key names
  col_names <- get_nonkey_names(epi_data)
  is_present <- function(original_predictor) {
    # fixed = TRUE: column names are literal text, not regular expressions
    grepl(original_predictor, col_names, fixed = TRUE) & !(col_names %in% predictors)
  }
  # `.init` keeps an empty `predictors` from failing inside `reduce()`
  is_modified <- map(predictors, is_present) %>%
    reduce(`|`, .init = rep(FALSE, length(col_names)))
  new_predictors <- col_names[is_modified]
  return(c(unchanged_predictors, new_predictors))
}
56+
57+
#' get a rolling average for the named columns
#' @description
#' add column(s) that are the rolling means of the specified columns, computed
#' with `epi_slide` using `before = width - 1`. Missing values inside a window
#' are dropped before averaging. Each new column is named `<col>_m<width>`.
#' Currently only group_by's on the geo_value. Should probably extend to more
#' keys if you have them
#' @param epi_data the dataset
#' @param width the number of days (or examples, the sliding isn't time-aware) to use
#' @param cols_to_mean the non-key columns to take the mean over. `NULL` means all
#' @return `epi_data`, ungrouped, with one extra column per entry of `cols_to_mean`
#' @importFrom slider slide_dbl
#' @importFrom epiprocess epi_slide
#' @export
rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) {
  cols_to_mean <- get_trainable_names(epi_data, cols_to_mean)
  epi_data %<>% group_by(geo_value)
  for (col in cols_to_mean) {
    mean_name <- paste0(col, "_m", width)
    # the argument is `na.rm`; a misspelled name is silently swallowed by `...`
    # and NAs would then propagate into the mean
    epi_data %<>% epi_slide(
      ~ mean(.x[[col]], na.rm = TRUE),
      before = width - 1L,
      new_col_name = mean_name
    )
  }
  epi_data %<>% ungroup()
  return(epi_data)
}
79+
80+
#' get a rolling standard deviation for the named columns
#' @description
#' A rolling standard deviation, based off of a rolling mean. First it
#' calculates a rolling mean with width `mean_width`, then takes the square
#' root of the squared difference between that and the actual value, averaged
#' over `sd_width`. Each new column is named `<col>_sd<sd_width>`; the
#' intermediate mean column is named `<col>_m<mean_width>`.
#' @param epi_data the dataset
#' @param sd_width the number of days (or examples, the sliding isn't
#'   time-aware) to use for the standard deviation calculation
#' @param mean_width like `sd_width`, but it governs the mean. Should be less
#'   than the `sd_width`, and if `NULL` (the default) it is half of `sd_width`
#'   (so 14 in the complete default case)
#' @param cols_to_sd the non-key columns to take the sd over. `NULL` means all
#' @param keep_mean bool, if `TRUE`, it keeps the intermediate mean column
#' @return the dataset, ungrouped, with the sd column(s) added (and the mean
#'   column(s) too when `keep_mean` is `TRUE`)
#' @importFrom epiprocess epi_slide
#' @export
rolling_sd <- function(epi_data, sd_width = 28L, mean_width = NULL, cols_to_sd = NULL, keep_mean = FALSE) {
  if (is.null(mean_width)) {
    mean_width <- as.integer(ceiling(sd_width / 2))
  }
  cols_to_sd <- get_trainable_names(epi_data, cols_to_sd)
  result <- epi_data
  for (col in cols_to_sd) {
    # regroup on every pass, since the result is rebuilt from `epi_data` below
    result %<>% group_by(geo_value)
    mean_name <- paste0(col, "_m", mean_width)
    sd_name <- paste0(col, "_sd", sd_width)
    result %<>% epi_slide(
      ~ mean(.x[[col]], na.rm = TRUE),
      before = mean_width - 1L,
      new_col_name = mean_name
    )
    # root-mean-square deviation of the raw value from its rolling mean
    result %<>% epi_slide(
      ~ sqrt(mean((.x[[mean_name]] - .x[[col]])^2, na.rm = TRUE)),
      before = sd_width - 1L,
      new_col_name = sd_name
    )
    if (!keep_mean) {
      # TODO make sure the extra info sticks around
      result %<>% select(-{{ mean_name }})
    }
    result %<>% dplyr_reconstruct(epi_data)
  }
  result %<>% ungroup()
  # return explicitly, as rolling_mean does, rather than relying on the value
  # of the trailing `%<>%` assignment
  return(result)
}

R/data_validation.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#' include empty strings
1414
#' @param args_list the args list created by [`epipredict::arx_args_list`]
1515
#' @export
16-
perform_sanity_checks <- function(epi_data,
16+
sanitize_args_predictors_trainer <- function(epi_data,
1717
outcome,
1818
predictors,
1919
trainer,
@@ -56,7 +56,7 @@ perform_sanity_checks <- function(epi_data,
5656
#' @export
5757
confirm_sufficient_data <- function(epi_data, ahead, args_input, buffer = 9) {
5858
if (!is.null(args_input$lags)) {
59-
lag_max <- max(args_input$lags)
59+
lag_max <- max(unlist(args_input$lags))
6060
} else {
6161
lag_max <- 14 # default value of 2 weeks
6262
}

R/epipredict_utilities.R

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,25 @@
22
#' add the default steps for arx_forecaster
33
#' @description
44
#' add the default steps for arx_forecaster
5-
#' @param rec an [`epipredict::epi_recipe`]
5+
#' @param preproc an [`epipredict::epi_recipe`]
66
#' @param outcome a character of the column to be predicted
77
#' @param predictors a character vector of the columns used as predictors
88
#' @param args_list an [`epipredict::arx_args_list`]
99
#' @seealso [arx_postprocess] for the layer equivalent
1010
#' @importFrom epipredict step_epi_lag step_epi_ahead step_epi_naomit step_training_window
1111
#' @export
12-
arx_preprocess <- function(rec, outcome, predictors, args_list) {
12+
arx_preprocess <- function(preproc, outcome, predictors, args_list) {
1313
# input already validated
1414
lags <- args_list$lags
1515
for (l in seq_along(lags)) {
1616
p <- predictors[l]
17-
rec %<>% step_epi_lag(!!p, lag = lags[[l]])
17+
preproc %<>% step_epi_lag(!!p, lag = lags[[l]])
1818
}
19-
rec %<>%
19+
preproc %<>%
2020
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
2121
step_epi_naomit() %>%
2222
step_training_window(n_recent = args_list$n_training)
23-
return(rec)
23+
return(preproc)
2424
}
2525

2626
# TODO replace with `layer_arx_forecaster`

R/forecaster_flatline.R

+1-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ flatline_fc <- function(epi_data,
4040
args_list <- do.call(flatline_args_list, args_input)
4141
# if you want to ignore extra_sources, setting predictors is the way to do it
4242
predictors <- c(outcome, extra_sources)
43-
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, NULL, args_list)
44-
args_list <- argsPredictorsTrainer[[1]]
45-
predictors <- argsPredictorsTrainer[[2]]
43+
c(args_list, predictors, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, NULL, args_list)
4644
# end of the copypasta
4745
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
4846
# epipredict

R/forecaster_scaled_pop.R

+4-6
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
#' @param quantile_levels The quantile levels to predict. Defaults to those required by
3636
#' covidhub.
3737
#' @seealso some utilities for making forecasters: [format_storage],
38-
#' [perform_sanity_checks]
38+
#' [sanitize_args_predictors_trainer]
3939
#' @importFrom epipredict epi_recipe step_population_scaling frosting arx_args_list layer_population_scaling
4040
#' @importFrom tibble tibble
41+
#' @importFrom zeallot %<-%
4142
#' @importFrom recipes all_numeric
4243
#' @export
4344
scaled_pop <- function(epi_data,
@@ -73,13 +74,10 @@ scaled_pop <- function(epi_data,
7374
args_input[["ahead"]] <- effective_ahead
7475
args_input[["quantile_levels"]] <- quantile_levels
7576
args_list <- do.call(arx_args_list, args_input)
76-
# if you want to ignore extra_sources, setting predictors is the way to do it
77+
# if you want to hardcode particular predictors in a particular forecaster
7778
predictors <- c(outcome, extra_sources)
7879
# TODO: Partial match quantile_level coming from here (on Dmitry's machine)
79-
argsPredictorsTrainer <- perform_sanity_checks(epi_data, outcome, predictors, trainer, args_list)
80-
args_list <- argsPredictorsTrainer[[1]]
81-
predictors <- argsPredictorsTrainer[[2]]
82-
trainer <- argsPredictorsTrainer[[3]]
80+
c(args_list, predictors, trainer) %<-% sanitize_args_predictors_trainer(epi_data, outcome, predictors, trainer, args_list)
8381
# end of the copypasta
8482
# finally, any other pre-processing (e.g. smoothing) that isn't performed by
8583
# epipredict

0 commit comments

Comments
 (0)