Skip to content

Commit 7135d7b

Browse files
committed
Merge branch 'dev' into grf-arx-hotfix
2 parents 109113b + ea34700 commit 7135d7b

18 files changed

+541
-67
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: epipredict
22
Title: Basic epidemiology forecasting methods
3-
Version: 0.1.3
3+
Version: 0.1.4
44
Authors@R: c(
55
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
66
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ importFrom(dplyr,filter)
242242
importFrom(dplyr,full_join)
243243
importFrom(dplyr,group_by)
244244
importFrom(dplyr,group_by_at)
245+
importFrom(dplyr,inner_join)
245246
importFrom(dplyr,join_by)
246247
importFrom(dplyr,left_join)
247248
importFrom(dplyr,mutate)
@@ -273,6 +274,7 @@ importFrom(hardhat,extract_recipe)
273274
importFrom(hardhat,refresh_blueprint)
274275
importFrom(hardhat,run_mold)
275276
importFrom(magrittr,"%>%")
277+
importFrom(magrittr,extract2)
276278
importFrom(recipes,bake)
277279
importFrom(recipes,detect_step)
278280
importFrom(recipes,prep)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
1515
## Improvements
1616

1717
- Add `step_adjust_latency`, which give several methods to adjust the forecast if the `forecast_date` is after the last day of data.
18+
- Fix `layer_population_scaling` default `by` with `other_keys`.
19+
- Make key column inference more consistent within the package and with current `epiprocess`.
20+
- Fix `quantile_reg()` producing error when asked to output just median-level predictions.
1821
- (temporary) ahead negative is allowed for `step_epi_ahead` until we have `step_epi_shift`
1922

2023
## Bug fixes

R/autoplot.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,10 @@ autoplot.epi_workflow <- function(
127127
if (!is.null(shift)) {
128128
edf <- mutate(edf, time_value = time_value + shift)
129129
}
130-
extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
131-
if (length(extra_keys) == 0L) extra_keys <- NULL
130+
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
132131
edf <- as_epi_df(edf,
133132
as_of = object$fit$meta$as_of,
134-
other_keys = extra_keys %||% character()
133+
other_keys = other_keys
135134
)
136135
if (is.null(predictions)) {
137136
return(autoplot(

R/epipredict-package.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
#' @importFrom cli cli_abort cli_warn
88
#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by
99
#' @importFrom dplyr full_join relocate summarise everything
10+
#' @importFrom dplyr inner_join
1011
#' @importFrom dplyr summarize filter mutate select left_join rename ungroup
12+
#' @importFrom magrittr extract2
1113
#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg
1214
#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match
1315
#' @importFrom stats poly predict lm residuals quantile

R/key_colnames.R

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
#' @export
2-
key_colnames.recipe <- function(x, ...) {
2+
key_colnames.recipe <- function(x, ..., exclude = character()) {
33
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
44
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
55
keys <- x$var_info$variable[x$var_info$role %in% "key"]
6-
c(geo_key, keys, time_key) %||% character(0L)
6+
full_key <- c(geo_key, keys, time_key) %||% character(0L)
7+
full_key[!full_key %in% exclude]
78
}
89

910
#' @export
10-
key_colnames.epi_workflow <- function(x, ...) {
11+
key_colnames.epi_workflow <- function(x, ..., exclude = character()) {
1112
# safer to look at the mold than the preprocessor
1213
mold <- hardhat::extract_mold(x)
13-
molded_names <- names(mold$extras$roles)
14-
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
15-
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
16-
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
17-
c(geo_key, keys, time_key) %||% character(0L)
14+
molded_roles <- mold$extras$roles
15+
extras <- bind_cols(molded_roles$geo_value, molded_roles$key, molded_roles$time_value)
16+
full_key <- names(extras)
17+
if (length(full_key) == 0L) {
18+
# No epikeytime role assignment; infer from all columns:
19+
potential_keys <- c("geo_value", "time_value")
20+
full_key <- potential_keys[potential_keys %in% names(bind_cols(molded_roles))]
21+
}
22+
full_key[!full_key %in% exclude]
1823
}
1924

2025
kill_time_value <- function(v) {

R/layer_population_scaling.R

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919
#' inverting the existing scaling.
2020
#' @param by A (possibly named) character vector of variables to join by.
2121
#'
22-
#' If `NULL`, the default, the function will perform a natural join, using all
23-
#' variables in common across the `epi_df` produced by the `predict()` call
24-
#' and the user-provided dataset.
25-
#' If columns in that `epi_df` and `df` have the same name (and aren't
26-
#' included in `by`), `.df` is added to the one from the user-provided data
27-
#' to disambiguate.
22+
#' If `NULL`, the default, the function will try to infer a reasonable set of
23+
#' columns. First, it will try to join by all variables in the test data with
24+
#' roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in `df`;
25+
#' these roles are automatically set if you are using an `epi_df`, or you can
26+
#' use, e.g., `update_role`. If no such roles are set, it will try to perform a
27+
#' natural join, using variables in common between the training/test data and
28+
#' population data.
29+
#'
30+
#' If columns in the training/testing data and `df` have the same name (and
31+
#' aren't included in `by`), a `.df` suffix is added to the one from the
32+
#' user-provided data to disambiguate.
2833
#'
2934
#' To join by different variables on the `epi_df` and `df`, use a named vector.
3035
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
@@ -135,6 +140,26 @@ slather.layer_population_scaling <-
135140
)
136141
rlang::check_dots_empty()
137142

143+
if (is.null(object$by)) {
144+
# Assume `layer_predict` has calculated the prediction keys and other
145+
# layers don't change the prediction key colnames:
146+
prediction_key_colnames <- names(components$keys)
147+
lhs_potential_keys <- prediction_key_colnames
148+
rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col))
149+
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
150+
suggested_min_keys <- kill_time_value(lhs_potential_keys)
151+
if (!all(suggested_min_keys %in% object$by)) {
152+
cli_warn(c(
153+
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
154+
but {?wasn't/weren't} found in the population `df`.",
155+
"i" = "Defaulting to join by {object$by}",
156+
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
157+
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
158+
">" = "Manually specify `by =` to silence"
159+
), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys")
160+
}
161+
}
162+
138163
object$by <- object$by %||% intersect(
139164
epi_keys_only(components$predictions),
140165
colnames(select(object$df, !object$df_pop_col))
@@ -152,10 +177,12 @@ slather.layer_population_scaling <-
152177
suffix <- ifelse(object$create_new, object$suffix, "")
153178
col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions))
154179

155-
components$predictions <- left_join(
180+
components$predictions <- inner_join(
156181
components$predictions,
157182
object$df,
158183
by = object$by,
184+
relationship = "many-to-one",
185+
unmatched = c("error", "drop"),
159186
suffix = c("", ".df")
160187
) %>%
161188
mutate(across(

R/make_quantile_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ make_quantile_reg <- function() {
112112

113113
# can't make a method because object is second
114114
out <- switch(type,
115-
rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile
115+
rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile
116116
rqs = {
117117
x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
118118
dist_quantiles(x, list(object$tau))

R/step_population_scaling.R

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,25 @@
1616
#' inverting the existing scaling.
1717
#' @param by A (possibly named) character vector of variables to join by.
1818
#'
19-
#' If `NULL`, the default, the function will perform a natural join, using all
20-
#' variables in common across the `epi_df` produced by the `predict()` call
21-
#' and the user-provided dataset.
22-
#' If columns in that `epi_df` and `df` have the same name (and aren't
23-
#' included in `by`), `.df` is added to the one from the user-provided data
24-
#' to disambiguate.
19+
#' If `NULL`, the default, the function will try to infer a reasonable set of
20+
#' columns. First, it will try to join by all variables in the training/test
21+
#' data with roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in
22+
#' `df`; these roles are automatically set if you are using an `epi_df`, or you
23+
#' can use, e.g., `update_role`. If no such roles are set, it will try to
24+
#' perform a natural join, using variables in common between the training/test
25+
#' data and population data.
26+
#'
27+
#' If columns in the training/testing data and `df` have the same name (and
28+
#' aren't included in `by`), a `.df` suffix is added to the one from the
29+
#' user-provided data to disambiguate.
2530
#'
2631
#' To join by different variables on the `epi_df` and `df`, use a named vector.
2732
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
2833
#' to `df$states`. To join by multiple variables, use a vector with length > 1.
2934
#' For example, `by = c("geo_value" = "states", "county" = "county")` will match
3035
#' `epi_df$geo_value` to `df$states` and `epi_df$county` to `df$county`.
3136
#'
32-
#' See [dplyr::left_join()] for more details.
37+
#' See [dplyr::inner_join()] for more details.
3338
#' @param df_pop_col the name of the column in the data frame `df` that
3439
#' contains the population data and will be used for scaling.
3540
#' This should be one column.
@@ -89,13 +94,25 @@ step_population_scaling <-
8994
suffix = "_scaled",
9095
skip = FALSE,
9196
id = rand_id("population_scaling")) {
92-
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id)
93-
arg_is_lgl(create_new, skip)
94-
arg_is_chr(df_pop_col, suffix, id)
97+
if (rlang::dots_n(...) == 0L) {
98+
cli_abort(c(
99+
"`...` must not be empty.",
100+
">" = "Please provide one or more tidyselect expressions in `...`
101+
specifying the columns to which scaling should be applied.",
102+
">" = "If you really want to list `step_population_scaling` in your
103+
recipe but not have it do anything, you can use a tidyselection
104+
that selects zero variables, such as `c()`."
105+
))
106+
}
107+
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id)
108+
arg_is_chr(role, df_pop_col, suffix, id)
109+
hardhat::validate_column_names(df, df_pop_col)
95110
arg_is_chr(by, allow_null = TRUE)
111+
arg_is_numeric(rate_rescaling)
96112
if (rate_rescaling <= 0) {
97113
cli_abort("`rate_rescaling` must be a positive number.")
98114
}
115+
arg_is_lgl(create_new, skip)
99116

100117
recipes::add_step(
101118
recipe,
@@ -138,6 +155,42 @@ step_population_scaling_new <-
138155

139156
#' @export
140157
prep.step_population_scaling <- function(x, training, info = NULL, ...) {
158+
if (is.null(x$by)) {
159+
rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col)
160+
lhs_potential_keys <- info %>%
161+
filter(role %in% c("geo_value", "key", "time_value")) %>%
162+
extract2("variable") %>%
163+
unique() # in case of weird var with multiple of above roles
164+
if (length(lhs_potential_keys) == 0L) {
165+
# We're working with a recipe and tibble, and *_role hasn't set up any of
166+
# the above roles. Let's say any column could actually act as a key, and
167+
# lean on `intersect` below to make this something reasonable.
168+
lhs_potential_keys <- names(training)
169+
}
170+
suggested_min_keys <- info %>%
171+
filter(role %in% c("geo_value", "key")) %>%
172+
extract2("variable") %>%
173+
unique()
174+
# (0 suggested keys if we weren't given any epikeytime var info.)
175+
x$by <- intersect(lhs_potential_keys, rhs_potential_keys)
176+
if (length(x$by) == 0L) {
177+
cli_stop(c(
178+
"Couldn't guess a default for `by`",
179+
">" = "Please rename columns in your population data to match those in your training data,
180+
or manually specify `by =` in `step_population_scaling()`."
181+
), class = "epipredict__step_population_scaling__default_by_no_intersection")
182+
}
183+
if (!all(suggested_min_keys %in% x$by)) {
184+
cli_warn(c(
185+
"{setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data,
186+
but {?wasn't/weren't} found in the population `df`.",
187+
"i" = "Defaulting to join by {x$by}.",
188+
">" = "Double-check whether column names on the population `df` match those for your training data.",
189+
">" = "Consider using population data with breakdowns by {suggested_min_keys}.",
190+
">" = "Manually specify `by =` to silence."
191+
), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys")
192+
}
193+
}
141194
step_population_scaling_new(
142195
terms = x$terms,
143196
role = x$role,
@@ -156,10 +209,14 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {
156209

157210
#' @export
158211
bake.step_population_scaling <- function(object, new_data, ...) {
159-
object$by <- object$by %||% intersect(
160-
epi_keys_only(new_data),
161-
colnames(select(object$df, !object$df_pop_col))
162-
)
212+
if (is.null(object$by)) {
213+
cli::cli_abort(c(
214+
"`by` was not set and no default was filled in",
215+
">" = "If this was a fit recipe generated from an older version
216+
of epipredict that you loaded in from a file,
217+
please regenerate with the current version of epipredict."
218+
))
219+
}
163220
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
164221
hardhat::validate_column_names(new_data, joinby$x)
165222
hardhat::validate_column_names(object$df, joinby$y)
@@ -177,7 +234,10 @@ bake.step_population_scaling <- function(object, new_data, ...) {
177234
suffix <- ifelse(object$create_new, object$suffix, "")
178235
col_to_remove <- setdiff(colnames(object$df), colnames(new_data))
179236

180-
left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>%
237+
inner_join(new_data, object$df,
238+
by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
239+
suffix = c("", ".df")
240+
) %>%
181241
mutate(
182242
across(
183243
all_of(object$columns),

R/utils-latency.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ drop_ignored_keys <- function(training, keys_to_ignore) {
359359
# note that the extra parenthesis black magic is described here: https://github.com/tidyverse/dplyr/issues/6194
360360
# and is needed to bypass an incomplete port of `across` functions to `if_any`
361361
training %>%
362+
ungroup() %>%
362363
filter((dplyr::if_all(
363364
names(keys_to_ignore),
364365
~ . %nin% keys_to_ignore[[cur_column()]]

0 commit comments

Comments
 (0)