35
35
# ' @importFrom zeallot %<-%
36
36
# ' @importFrom recipes all_numeric
37
37
# ' @export
38
- scaled_pop_seasonal <- function (epi_data ,
39
- outcome ,
40
- extra_sources = " " ,
41
- ahead = 1 ,
42
- pop_scaling = TRUE ,
43
- drop_non_seasons = FALSE ,
44
- scale_method = c(" quantile" , " std" , " none" ),
45
- center_method = c(" median" , " mean" , " none" ),
46
- nonlin_method = c(" quart_root" , " none" ),
47
- seasonal_method = c(" none" , " flu" , " covid" , " indicator" , " window" , " climatological" ),
48
- seasonal_backward_window = 5 * 7 ,
49
- seasonal_forward_window = 3 * 7 ,
50
- train_residual = FALSE ,
51
- trainer = epipredict :: quantile_reg(),
52
- quantile_levels = covidhub_probs(),
53
- filter_source = " " ,
54
- filter_agg_level = " " ,
55
- clip_lower = TRUE ,
56
- ... ) {
38
+ scaled_pop_seasonal <- function (
39
+ epi_data ,
40
+ outcome ,
41
+ extra_sources = " " ,
42
+ ahead = 1 ,
43
+ pop_scaling = TRUE ,
44
+ drop_non_seasons = FALSE ,
45
+ scale_method = c(" quantile" , " std" , " none" ),
46
+ center_method = c(" median" , " mean" , " none" ),
47
+ nonlin_method = c(" quart_root" , " none" ),
48
+ seasonal_method = c(" none" , " flu" , " covid" , " indicator" , " window" , " climatological" ),
49
+ seasonal_backward_window = 5 * 7 ,
50
+ seasonal_forward_window = 3 * 7 ,
51
+ train_residual = FALSE ,
52
+ trainer = epipredict :: quantile_reg(),
53
+ quantile_levels = covidhub_probs(),
54
+ filter_source = " " ,
55
+ filter_agg_level = " " ,
56
+ clip_lower = TRUE ,
57
+ ...
58
+ ) {
57
59
scale_method <- arg_match(scale_method )
58
60
center_method <- arg_match(center_method )
59
61
nonlin_method <- arg_match(nonlin_method )
@@ -62,6 +64,9 @@ scaled_pop_seasonal <- function(epi_data,
62
64
extra_sources <- unwrap_argument(extra_sources )
63
65
trainer <- unwrap_argument(trainer )
64
66
67
+ if (typeof(seasonal_method ) == " list" ) {
68
+ seasonal_method <- seasonal_method [[1 ]]
69
+ }
65
70
if (all(seasonal_method == c(" none" , " flu" , " covid" , " indicator" , " window" , " climatological" ))) {
66
71
seasonal_method <- " none"
67
72
}
@@ -100,7 +105,8 @@ scaled_pop_seasonal <- function(epi_data,
100
105
args_list <- inject(default_args_list(!!! args_input ))
101
106
# if you want to hardcode particular predictors in a particular forecaster
102
107
predictors <- c(outcome , extra_sources )
103
- c(args_list , predictors , trainer ) %<- % sanitize_args_predictors_trainer(epi_data , outcome , predictors , trainer , args_list )
108
+ c(args_list , predictors , trainer ) %<- %
109
+ sanitize_args_predictors_trainer(epi_data , outcome , predictors , trainer , args_list )
104
110
105
111
if (" season_week" %nin % names(epi_data )) {
106
112
epi_data %<> % add_season_info()
@@ -116,13 +122,27 @@ scaled_pop_seasonal <- function(epi_data,
116
122
season_data <- epi_data
117
123
}
118
124
# TODO: Jank way to avoid having hhs_region get centered; this isn't very general
119
- learned_params <- calculate_whitening_params(season_data , setdiff(predictors , " hhs_region" ), scale_method , center_method , nonlin_method )
125
+ learned_params <- calculate_whitening_params(
126
+ season_data ,
127
+ setdiff(predictors , " hhs_region" ),
128
+ scale_method ,
129
+ center_method ,
130
+ nonlin_method
131
+ )
120
132
epi_data %<> % data_whitening(setdiff(predictors , " hhs_region" ), learned_params , nonlin_method )
121
133
122
134
# get the seasonal features
123
135
# first add PCA
124
136
if ((" flu" %in% seasonal_method ) || (" covid" %in% seasonal_method )) {
125
- epi_data <- compute_pca(epi_data , seasonal_method , ahead , scale_method , center_method , nonlin_method , normalize = train_residual )
137
+ epi_data <- compute_pca(
138
+ epi_data ,
139
+ seasonal_method ,
140
+ ahead ,
141
+ scale_method ,
142
+ center_method ,
143
+ nonlin_method ,
144
+ normalize = train_residual
145
+ )
126
146
127
147
if (train_residual ) {
128
148
epi_data <- epi_data %> % mutate(across(all_of(outcome ), ~ .x - PC1 ))
@@ -172,14 +192,15 @@ scaled_pop_seasonal <- function(epi_data,
172
192
# preprocessing supported by epipredict
173
193
preproc <- epi_recipe(epi_data )
174
194
if (pop_scaling ) {
175
- preproc %<> % step_population_scaling(
176
- all_of(predictors ),
177
- df = epidatasets :: state_census ,
178
- df_pop_col = " pop" ,
179
- create_new = FALSE ,
180
- rate_rescaling = 1e5 ,
181
- by = c(" geo_value" = " abbr" )
182
- )
195
+ preproc %<> %
196
+ step_population_scaling(
197
+ all_of(predictors ),
198
+ df = epidatasets :: state_census ,
199
+ df_pop_col = " pop" ,
200
+ create_new = FALSE ,
201
+ rate_rescaling = 1e5 ,
202
+ by = c(" geo_value" = " abbr" )
203
+ )
183
204
}
184
205
if (" indicator" %in% seasonal_method ) {
185
206
preproc %<> %
@@ -201,14 +222,16 @@ scaled_pop_seasonal <- function(epi_data,
201
222
postproc <- frosting()
202
223
postproc %<> % arx_postprocess(trainer , args_list )
203
224
if (pop_scaling ) {
204
- postproc %<> % layer_population_scaling(
205
- .pred , .pred_distn ,
206
- df = epidatasets :: state_census ,
207
- df_pop_col = " pop" ,
208
- create_new = FALSE ,
209
- rate_rescaling = 1e5 ,
210
- by = c(" geo_value" = " abbr" )
211
- )
225
+ postproc %<> %
226
+ layer_population_scaling(
227
+ .pred ,
228
+ .pred_distn ,
229
+ df = epidatasets :: state_census ,
230
+ df_pop_col = " pop" ,
231
+ create_new = FALSE ,
232
+ rate_rescaling = 1e5 ,
233
+ by = c(" geo_value" = " abbr" )
234
+ )
212
235
}
213
236
# with all the setup done, we execute and format
214
237
pred <- run_workflow_and_format(preproc , postproc , trainer , season_data , epi_data )
@@ -217,7 +240,10 @@ scaled_pop_seasonal <- function(epi_data,
217
240
# finally, any postprocessing not supported by epipredict e.g. calibration
218
241
#
219
242
# undo subtraction if we're training on residuals
220
- if (train_residual && ((" flu" %in% seasonal_method ) || (" covid" %in% seasonal_method ) || (" climatological" %in% seasonal_method ))) {
243
+ if (
244
+ train_residual &&
245
+ ((" flu" %in% seasonal_method ) || (" covid" %in% seasonal_method ) || (" climatological" %in% seasonal_method ))
246
+ ) {
221
247
pred <- pred %> %
222
248
mutate(epi_week = epiweek(target_end_date )) %> %
223
249
left_join(values_subtracted , by = join_by(geo_value , source , epi_week == epiweek )) %> %
@@ -228,7 +254,12 @@ scaled_pop_seasonal <- function(epi_data,
228
254
# reintroduce color into the value
229
255
pred_final <- pred %> %
230
256
rename({{ outcome }} : = value ) %> %
231
- data_coloring(outcome , learned_params , join_cols = key_colnames(epi_data , exclude = " time_value" ), nonlin_method = nonlin_method ) %> %
257
+ data_coloring(
258
+ outcome ,
259
+ learned_params ,
260
+ join_cols = key_colnames(epi_data , exclude = " time_value" ),
261
+ nonlin_method = nonlin_method
262
+ ) %> %
232
263
rename(value = {{ outcome }})
233
264
if (clip_lower ) {
234
265
pred_final %<> % mutate(value = pmax(0 , value ))
0 commit comments