1
- # TODO add latency to default forecaster
2
1
# ' Direct autoregressive forecaster with covariates
3
2
# '
4
3
# ' This is an autoregressive forecasting model for
@@ -54,7 +53,7 @@ arx_forecaster <- function(
54
53
55
54
preds <- forecast(
56
55
wf ,
57
- fill_locf = TRUE ,
56
+ fill_locf = is.null( args_list $ adjust_latency ) ,
58
57
n_recent = args_list $ nafill_buffer ,
59
58
forecast_date = args_list $ forecast_date %|| % max(epi_data $ time_value )
60
59
) %> %
@@ -119,6 +118,17 @@ arx_fcast_epi_workflow <- function(
119
118
if (! (is.null(trainer ) || is_regression(trainer ))) {
120
119
cli :: cli_abort(" {trainer} must be a `{parsnip}` model of mode 'regression'." )
121
120
}
121
+ # forecast_date is first what they set;
122
+ # if they don't and they're not adjusting latency, it defaults to the max time_value
123
+ # if they're adjusting as_of, it defaults to the as_of
124
+ latency_adjust_fd <- if (is.null(args_list $ adjust_latency )) {
125
+ max(epi_data $ time_value )
126
+ } else {
127
+ attributes(epi_data )$ metadata $ as_of
128
+ }
129
+ forecast_date <- args_list $ forecast_date %|| % latency_adjust_fd
130
+ target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
131
+
122
132
lags <- arx_lags_validator(predictors , args_list $ lags )
123
133
124
134
# --- preprocessor
@@ -128,26 +138,34 @@ arx_fcast_epi_workflow <- function(
128
138
r <- step_epi_lag(r , !! p , lag = lags [[l ]])
129
139
}
130
140
r <- r %> %
131
- step_epi_ahead(!! outcome , ahead = args_list $ ahead ) %> %
132
- step_epi_naomit() %> %
133
- step_training_window(n_recent = args_list $ n_training ) %> %
134
- {
135
- if (! is.null(args_list $ check_enough_data_n )) {
136
- check_enough_train_data(
137
- . ,
138
- all_predictors(),
139
- !! outcome ,
140
- n = args_list $ check_enough_data_n ,
141
- epi_keys = args_list $ check_enough_data_epi_keys ,
142
- drop_na = FALSE
143
- )
144
- } else {
145
- .
146
- }
141
+ step_epi_ahead(!! outcome , ahead = args_list $ ahead )
142
+ method <- args_list $ adjust_latency
143
+ if (! is.null(method )) {
144
+ if (method == " extend_ahead" ) {
145
+ r <- r %> % step_adjust_latency(all_outcomes(),
146
+ fixed_forecast_date = forecast_date ,
147
+ method = method
148
+ )
149
+ } else if (method == " extend_lags" ) {
150
+ r <- r %> % step_adjust_latency(all_predictors(),
151
+ fixed_forecast_date = forecast_date ,
152
+ method = method
153
+ )
147
154
}
155
+ }
156
+ r <- r %> %
157
+ step_epi_naomit() %> %
158
+ step_training_window(n_recent = args_list $ n_training )
159
+ if (! is.null(args_list $ check_enough_data_n )) {
160
+ r <- r %> % check_enough_train_data(
161
+ all_predictors(),
162
+ !! outcome ,
163
+ n = args_list $ check_enough_data_n ,
164
+ epi_keys = args_list $ check_enough_data_epi_keys ,
165
+ drop_na = FALSE
166
+ )
167
+ }
148
168
149
- forecast_date <- args_list $ forecast_date %|| % max(epi_data $ time_value )
150
- target_date <- args_list $ target_date %|| % (forecast_date + args_list $ ahead )
151
169
152
170
# --- postprocessor
153
171
f <- frosting() %> % layer_predict() # %>% layer_naomit()
@@ -159,11 +177,11 @@ arx_fcast_epi_workflow <- function(
159
177
))
160
178
args_list $ quantile_levels <- quantile_levels
161
179
trainer $ args $ quantile_levels <- rlang :: enquo(quantile_levels )
162
- f <- layer_quantile_distn(f , quantile_levels = quantile_levels ) %> %
180
+ f <- f %> %
181
+ layer_quantile_distn(quantile_levels = quantile_levels ) %> %
163
182
layer_point_from_distn()
164
183
} else {
165
- f <- layer_residual_quantiles(
166
- f ,
184
+ f <- f %> % layer_residual_quantiles(
167
185
quantile_levels = args_list $ quantile_levels ,
168
186
symmetrize = args_list $ symmetrize ,
169
187
by_key = args_list $ quantile_by_key
@@ -189,10 +207,15 @@ arx_fcast_epi_workflow <- function(
189
207
# ' @param n_training Integer. An upper limit for the number of rows per
190
208
# ' key that are used for training
191
209
# ' (in the time unit of the `epi_df`).
192
- # ' @param forecast_date Date. The date on which the forecast is created.
193
- # ' The default `NULL` will attempt to determine this automatically.
194
- # ' @param target_date Date. The date for which the forecast is intended.
195
- # ' The default `NULL` will attempt to determine this automatically.
210
+ # ' @param forecast_date Date. The date on which the forecast is created. The
211
+ # ' default `NULL` will attempt to determine this automatically either as the
212
+ # ' max time value if there is no latency adjustment, or as the `as_of` of
213
+ # ' `epi_data` if `adjust_latency` is non-`NULL`.
214
+ # ' @param target_date Date. The date for which the forecast is intended. The
215
+ # ' default `NULL` will attempt to determine this automatically as
216
+ # ' `forecast_date + ahead`.
217
+ # ' @param adjust_latency Character or `NULL`. one of the `method`s of
218
+ # ' `step_adjust_latency`, or `NULL` (in which case there is no adjustment).
196
219
# ' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
197
220
# ' prediction intervals. These are created by computing the quantiles of
198
221
# ' training residuals. A `NULL` value will result in point forecasts only.
@@ -238,6 +261,7 @@ arx_args_list <- function(
238
261
n_training = Inf ,
239
262
forecast_date = NULL ,
240
263
target_date = NULL ,
264
+ adjust_latency = NULL ,
241
265
quantile_levels = c(0.05 , 0.95 ),
242
266
symmetrize = TRUE ,
243
267
nonneg = TRUE ,
@@ -253,7 +277,7 @@ arx_args_list <- function(
253
277
254
278
arg_is_scalar(ahead , n_training , symmetrize , nonneg )
255
279
arg_is_chr(quantile_by_key , allow_empty = TRUE )
256
- arg_is_scalar(forecast_date , target_date , allow_null = TRUE )
280
+ arg_is_scalar(forecast_date , target_date , adjust_latency , allow_null = TRUE )
257
281
arg_is_date(forecast_date , target_date , allow_null = TRUE )
258
282
arg_is_nonneg_int(ahead , lags )
259
283
arg_is_lgl(symmetrize , nonneg )
@@ -282,6 +306,7 @@ arx_args_list <- function(
282
306
quantile_levels ,
283
307
forecast_date ,
284
308
target_date ,
309
+ adjust_latency ,
285
310
symmetrize ,
286
311
nonneg ,
287
312
max_lags ,
0 commit comments