#' inverting the existing scaling.
#' @param by A (possibly named) character vector of variables to join by.
#'
- #' If `NULL`, the default, the function will perform a natural join, using all
- #' variables in common across the `epi_df` produced by the `predict()` call
- #' and the user-provided dataset.
- #' If columns in that `epi_df` and `df` have the same name (and aren't
- #' included in `by`), `.df` is added to the one from the user-provided data
- #' to disambiguate.
+ #' If `NULL`, the default, the function will try to infer a reasonable set of
+ #' columns. First, it will try to join by all variables in the training/test
+ #' data with roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in
+ #' `df`; these roles are automatically set if you are using an `epi_df`, or you
+ #' can use, e.g., `update_role`. If no such roles are set, it will try to
+ #' perform a natural join, using variables in common between the training/test
+ #' data and population data.
+ #'
+ #' If columns in the training/testing data and `df` have the same name (and
+ #' aren't included in `by`), a `.df` suffix is added to the one from the
+ #' user-provided data to disambiguate.
#'
#' To join by different variables on the `epi_df` and `df`, use a named vector.
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
#' to `df$states`. To join by multiple variables, use a vector with length > 1.
#' For example, `by = c("geo_value" = "states", "county" = "county")` will match
#' `epi_df$geo_value` to `df$states` and `epi_df$county` to `df$county`.
#'
- #' See [dplyr::left_join()] for more details.
+ #' See [dplyr::inner_join()] for more details.
#' @param df_pop_col the name of the column in the data frame `df` that
#' contains the population data and will be used for scaling.
#' This should be one column.
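A hedged usage sketch (not part of this diff; `jhu`, `cases`, `state_pops`, `states`, and `pop` are hypothetical example names) showing the named-vector form of `by` together with `rate_rescaling`:

library(epipredict)
library(dplyr)

# Sketch only: `jhu` is assumed to be an `epi_df` with a `cases` column;
# `state_pops` is a data frame with state names in `states` and populations in `pop`.
rec <- epi_recipe(jhu) %>%
  step_population_scaling(
    cases,
    df = state_pops,
    df_pop_col = "pop",
    by = c("geo_value" = "states"), # match the epi_df's geo_value to state_pops$states
    rate_rescaling = 1e5 # report rates per 100,000 population
  )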
@@ -89,13 +94,25 @@ step_population_scaling <-
           suffix = "_scaled",
           skip = FALSE,
           id = rand_id("population_scaling")) {
- arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id)
- arg_is_lgl(create_new, skip)
- arg_is_chr(df_pop_col, suffix, id)
+ if (rlang::dots_n(...) == 0L) {
+   cli_abort(c(
+     "`...` must not be empty.",
+     ">" = "Please provide one or more tidyselect expressions in `...`
+            specifying the columns to which scaling should be applied.",
+     ">" = "If you really want to list `step_population_scaling` in your
+            recipe but not have it do anything, you can use a tidyselection
+            that selects zero variables, such as `c()`."
+   ))
+ }
+ arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id)
+ arg_is_chr(role, df_pop_col, suffix, id)
+ hardhat::validate_column_names(df, df_pop_col)
arg_is_chr(by, allow_null = TRUE)
+ arg_is_numeric(rate_rescaling)
if (rate_rescaling <= 0) {
  cli_abort("`rate_rescaling` must be a positive number.")
}
+ arg_is_lgl(create_new, skip)

recipes::add_step(
  recipe,
@@ -138,6 +155,42 @@ step_population_scaling_new <-

#' @export
prep.step_population_scaling <- function(x, training, info = NULL, ...) {
+ if (is.null(x$by)) {
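+   # Candidate join keys on the population (`df`) side: every column of `df`
+   # except the population column itself.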
+   rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col)
+   lhs_potential_keys <- info %>%
+     filter(role %in% c("geo_value", "key", "time_value")) %>%
+     extract2("variable") %>%
+     unique() # in case of weird var with multiple of above roles
+   if (length(lhs_potential_keys) == 0L) {
+     # We're working with a recipe and tibble, and *_role hasn't set up any of
+     # the above roles. Let's say any column could actually act as a key, and
+     # lean on `intersect` below to make this something reasonable.
+     lhs_potential_keys <- names(training)
+   }
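+   # Epi key columns (geo_value plus any other keys, but not time_value) that
+   # the population data would ideally be broken down by; only used below to
+   # warn when some of them are missing from `df`.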
+   suggested_min_keys <- info %>%
+     filter(role %in% c("geo_value", "key")) %>%
+     extract2("variable") %>%
+     unique()
+   # (0 suggested keys if we weren't given any epikeytime var info.)
+   x$by <- intersect(lhs_potential_keys, rhs_potential_keys)
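+   # e.g., training data keyed by geo_value and time_value joined against a
+   # `df` whose only non-population column is "geo_value" ends up with
+   # by = "geo_value" here.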
+   if (length(x$by) == 0L) {
+     cli_abort(c(
+       "Couldn't guess a default for `by`.",
+       ">" = "Please rename columns in your population data to match those in your training data,
+              or manually specify `by =` in `step_population_scaling()`."
+     ), class = "epipredict__step_population_scaling__default_by_no_intersection")
+   }
+   if (!all(suggested_min_keys %in% x$by)) {
+     cli_warn(c(
+       "{setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data,
+        but {?wasn't/weren't} found in the population `df`.",
+       "i" = "Defaulting to join by {x$by}.",
+       ">" = "Double-check whether column names on the population `df` match those for your training data.",
+       ">" = "Consider using population data with breakdowns by {suggested_min_keys}.",
+       ">" = "Manually specify `by =` to silence."
+     ), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys")
+   }
+ }
step_population_scaling_new(
  terms = x$terms,
  role = x$role,
@@ -156,10 +209,14 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {

#' @export
bake.step_population_scaling <- function(object, new_data, ...) {
- object$by <- object$by %||% intersect(
-   epi_keys_only(new_data),
-   colnames(select(object$df, !object$df_pop_col))
- )
+ if (is.null(object$by)) {
+   cli::cli_abort(c(
+     "`by` was not set and no default was filled in",
+     ">" = "If this was a fit recipe generated from an older version
+            of epipredict that you loaded in from a file,
+            please regenerate with the current version of epipredict."
+   ))
+ }
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
hardhat::validate_column_names(new_data, joinby$x)
hardhat::validate_column_names(object$df, joinby$y)
@@ -177,7 +234,10 @@ bake.step_population_scaling <- function(object, new_data, ...) {
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(new_data))

- left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>%
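+ # The stricter join arguments below require each new_data row to match
+ # exactly one population row (an unmatched new_data row is an error), while
+ # unused rows of the population `df` are silently dropped.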
+ inner_join(new_data, object$df,
+   by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
+   suffix = c("", ".df")
+ ) %>%
  mutate(
    across(
      all_of(object$columns),