|
| 1 | +# various reusable transforms to apply before handing to epipredict |
| 2 | + |
| 3 | +#' extract the non-key, non-smoothed columns from epi_data |
| 4 | +#' @keywords internal |
| 5 | +#' @param epi_data the `epi_df` |
| 6 | +#' @param cols vector of column names to use. If `NULL`, fill with all non-key columns |
| 7 | +get_trainable_names <- function(epi_data, cols) { |
| 8 | + if (is.null(cols)) { |
| 9 | + cols <- get_nonkey_names(epi_data) |
| 10 | + # exclude anything with the same naming schema as the rolling average/sd created below |
| 11 | + cols <- cols[!grepl("_\\w{1,2}\\d+", cols)] |
| 12 | + } |
| 13 | + return(cols) |
| 14 | +} |
| 15 | + |
| 16 | +#' just the names which aren't keys for an epi_df |
| 17 | +#' @description |
| 18 | +#' names, but it excludes keys |
| 19 | +#' @param epi_data the epi_df |
| 20 | +get_nonkey_names <- function(epi_data) { |
| 21 | + cols <- names(epi_data) |
| 22 | + cols <- cols[!(cols %in% c("geo_value", "time_value", attr(epi_data, "metadata")$other_keys))] |
| 23 | + return(cols) |
| 24 | +} |
| 25 | + |
| 26 | + |
| 27 | +#' update the predictors to only contain the smoothed/sd versions of cols |
| 28 | +#' @description |
| 29 | +#' modifies the list of preditors so that any which have been modified have the |
| 30 | +#' modified versions included, and not the original. Should only be applied |
| 31 | +#' after both rolling_mean and rolling_sd. |
| 32 | +#' @param epi_data the epi_df, only included to get the non-key column names |
| 33 | +#' @param cols_modified the list of columns which have been modified. If this is `NULL`, that means we were modifying every column. |
| 34 | +#' @param predictors the initial set of predictors; any unmodified are kept, any modified are replaced with the modified versions (e.g. "a" becoming "a_m17"). |
| 35 | +#' @importFrom purrr map map_chr reduce |
| 36 | +#' @return returns an updated list of predictors, with modified columns replaced and non-modified columns left intact. |
| 37 | +#' @export |
| 38 | +update_predictors <- function(epi_data, cols_modified, predictors) { |
| 39 | + if (!is.null(cols_modified)) { |
| 40 | + # if cols_modified isn't null, make sure we include predictors that weren't modified |
| 41 | + unchanged_predictors <- map(cols_modified, ~ !grepl(.x, predictors, fixed = TRUE)) %>% reduce(`&`) |
| 42 | + unchanged_predictors <- predictors[unchanged_predictors] |
| 43 | + } else { |
| 44 | + # if it's null, we've modified every predictor |
| 45 | + unchanged_predictors <- character(0L) |
| 46 | + } |
| 47 | + # all the non-key names |
| 48 | + col_names <- get_nonkey_names(epi_data) |
| 49 | + is_present <- function(original_predictor) { |
| 50 | + grepl(original_predictor, col_names) & !(col_names %in% predictors) |
| 51 | + } |
| 52 | + is_modified <- map(predictors, is_present) %>% reduce(`|`) |
| 53 | + new_predictors <- col_names[is_modified] |
| 54 | + return(c(unchanged_predictors, new_predictors)) |
| 55 | +} |
| 56 | + |
| 57 | +#' get a rolling average for the named columns |
| 58 | +#' @description |
| 59 | +#' add column(s) that are the rolling means of the specified columns, as |
| 60 | +#' implemented by slider. Defaults to the previous 7 days. |
| 61 | +#' Currently only group_by's on the geo_value. Should probably extend to more |
| 62 | +#' keys if you have them |
| 63 | +#' @param epi_data the dataset |
| 64 | +#' @param width the number of days (or examples, the sliding isn't time-aware) to use |
| 65 | +#' @param cols_to_mean the non-key columns to take the mean over. `NULL` means all |
| 66 | +#' @importFrom slider slide_dbl |
| 67 | +#' @importFrom epiprocess epi_slide |
| 68 | +#' @export |
| 69 | +rolling_mean <- function(epi_data, width = 7L, cols_to_mean = NULL) { |
| 70 | + cols_to_mean <- get_trainable_names(epi_data, cols_to_mean) |
| 71 | + epi_data %<>% group_by(geo_value) |
| 72 | + for (col in cols_to_mean) { |
| 73 | + mean_name <- paste0(col, "_m", width) |
| 74 | + epi_data %<>% epi_slide(~ mean(.x[[col]], rm.na = TRUE), before = width-1L, new_col_name = mean_name) |
| 75 | + } |
| 76 | + epi_data %<>% ungroup() |
| 77 | + return(epi_data) |
| 78 | +} |
| 79 | + |
| 80 | +#' get a rolling standard deviation for the named columns |
| 81 | +#' @description |
| 82 | +#' A rolling standard deviation, based off of a rolling mean. First it |
| 83 | +#' calculates a rolling mean with width `mean_width`, and then squares the |
| 84 | +#' difference between that and the actual value, averaged over `sd_width`. |
| 85 | +#' @param epi_data the dataset |
| 86 | +#' @param sd_width the number of days (or examples, the sliding isn't |
| 87 | +#' time-aware) to use for the standard deviation calculation |
| 88 | +#' @param mean_width like `sd_width`, but it governs the mean. Should be less |
| 89 | +#' than the `sd_width`, and if `NULL` (the default) it is half of `sd_width` |
| 90 | +#' (so 14 in the complete default case) |
| 91 | +#' @param cols_to_sd the non-key columns to take the sd over. `NULL` means all |
| 92 | +#' @param keep_mean bool, if `TRUE`, it retains keeps the mean column |
| 93 | +#' @importFrom epiprocess epi_slide |
| 94 | +#' @export |
| 95 | +rolling_sd <- function(epi_data, sd_width = 28L, mean_width = NULL, cols_to_sd = NULL, keep_mean = FALSE) { |
| 96 | + if (is.null(mean_width)) { |
| 97 | + mean_width <- as.integer(ceiling(sd_width / 2)) |
| 98 | + } |
| 99 | + cols_to_sd <- get_trainable_names(epi_data, cols_to_sd) |
| 100 | + result <- epi_data |
| 101 | + for (col in cols_to_sd) { |
| 102 | + result %<>% group_by(geo_value) |
| 103 | + mean_name <- paste0(col, "_m", mean_width) |
| 104 | + sd_name <- paste0(col, "_sd", sd_width) |
| 105 | + result %<>% epi_slide(~ mean(.x[[col]], na.rm = TRUE), before = mean_width-1L, new_col_name = mean_name) |
| 106 | + result %<>% epi_slide(~ sqrt(mean((.x[[mean_name]] - .x[[col]])^2, na.rm = TRUE)), before = sd_width-1, new_col_name = sd_name) |
| 107 | + if (!keep_mean) { |
| 108 | + # TODO make sure the extra info sticks around |
| 109 | + result %<>% select(-{{ mean_name }}) |
| 110 | + } |
| 111 | + result %<>% dplyr_reconstruct(epi_data) |
| 112 | + } |
| 113 | + result %<>% ungroup() |
| 114 | +} |
0 commit comments