Skip to content

Commit 9d70989

Browse files
authored
Merge pull request #421 from cmu-delphi/grf-arx-hotfix
Grf arx hotfix
2 parents ea34700 + 7135d7b commit 9d70989

File tree

8 files changed

+69
-22
lines changed

8 files changed

+69
-22
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Remotes:
7272
cmu-delphi/epidatasets,
7373
cmu-delphi/epidatr,
7474
cmu-delphi/epiprocess,
75+
cmu-delphi/epidatasets,
7576
dajmcdon/smoothqr
7677
Config/testthat/edition: 3
7778
Encoding: UTF-8

R/arx_forecaster.R

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,26 @@ arx_fcast_epi_workflow <- function(
186186

187187
# --- postprocessor
188188
f <- frosting() %>% layer_predict() # %>% layer_naomit()
189-
if (inherits(trainer, "quantile_reg")) {
189+
is_quantile_reg <- inherits(trainer, "quantile_reg") |
190+
(inherits(trainer, "rand_forest") & trainer$engine == "grf_quantiles")
191+
if (is_quantile_reg) {
190192
# add all quantile_level to the forecaster and update postprocessor
191-
quantile_levels <- sort(compare_quantile_args(
192-
args_list$quantile_levels,
193-
rlang::eval_tidy(trainer$args$quantile_levels)
194-
))
193+
if (inherits(trainer, "quantile_reg")) {
194+
quantile_levels <- sort(compare_quantile_args(
195+
args_list$quantile_levels,
196+
rlang::eval_tidy(trainer$args$quantile_levels),
197+
"qr"
198+
))
199+
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
200+
} else {
201+
quantile_levels <- sort(compare_quantile_args(
202+
args_list$quantile_levels,
203+
rlang::eval_tidy(trainer$eng_args$quantiles) %||% c(.1, .5, .9),
204+
"grf"
205+
))
206+
trainer$eng_args$quantiles <- rlang::enquo(quantile_levels)
207+
}
195208
args_list$quantile_levels <- quantile_levels
196-
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
197209
f <- f %>%
198210
layer_quantile_distn(quantile_levels = quantile_levels) %>%
199211
layer_point_from_distn()
@@ -345,9 +357,13 @@ print.arx_fcast <- function(x, ...) {
345357
NextMethod(name = name, ...)
346358
}
347359

348-
compare_quantile_args <- function(alist, tlist) {
360+
compare_quantile_args <- function(alist, tlist, train_method = c("qr", "grf")) {
361+
train_method <- rlang::arg_match(train_method)
349362
default_alist <- eval(formals(arx_args_list)$quantile_levels)
350-
default_tlist <- eval(formals(quantile_reg)$quantile_levels)
363+
default_tlist <- switch(train_method,
364+
"qr" = eval(formals(quantile_reg)$quantile_levels),
365+
"grf" = c(.1, .5, .9)
366+
)
351367
if (setequal(alist, default_alist)) {
352368
if (setequal(tlist, default_tlist)) {
353369
return(sort(unique(union(alist, tlist))))

R/canned-epipred.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ print.canned_epipred <- function(x, name, ...) {
7777
fn_meta <- function() {
7878
cli::cli_ul()
7979
cli::cli_li("Geography: {.field {x$metadata$training$geo_type}},")
80-
if (!is.null(x$metadata$training$other_keys)) {
81-
cli::cli_li("Other keys: {.field {x$metadata$training$other_keys}},")
80+
other_keys <- x$metadata$training$other_keys
81+
if (!is.null(other_keys) && length(other_keys) > 0L) {
82+
cli::cli_li("Other keys: {.field {other_keys}},")
8283
}
8384
cli::cli_li("Time type: {.field {x$metadata$training$time_type}},")
8485
cli::cli_li("Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}.")

man/step_adjust_latency.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/arx_args_list.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@
124124

125125
# arx forecaster disambiguates quantiles
126126

127+
Code
128+
compare_quantile_args(alist / 10, 1:9 / 10, "grf")
129+
Condition
130+
Error in `compare_quantile_args()`:
131+
! You have specified different, non-default, quantiles in the trainier and `arx_args` options.
132+
i Please only specify quantiles in one location.
133+
134+
---
135+
127136
Code
128137
compare_quantile_args(alist, tlist)
129138
Condition

tests/testthat/_snaps/snapshots.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,6 @@
10931093
10941094
Training data was an <epi_df> with:
10951095
* Geography: state,
1096-
* Other keys: ,
10971096
* Time type: day,
10981097
* Using data up-to-date as of: 2022-05-31.
10991098
* With the last data available on 2021-12-31
@@ -1117,7 +1116,6 @@
11171116
11181117
Training data was an <epi_df> with:
11191118
* Geography: state,
1120-
* Other keys: ,
11211119
* Time type: day,
11221120
* Using data up-to-date as of: 2022-05-31.
11231121
* With the last data available on 2021-12-31
@@ -1142,7 +1140,6 @@
11421140
11431141
Training data was an <epi_df> with:
11441142
* Geography: state,
1145-
* Other keys: ,
11461143
* Time type: day,
11471144
* Using data up-to-date as of: 2022-05-31.
11481145
* With the last data available on 2021-12-31

tests/testthat/test-arx_args_list.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ test_that("arx forecaster disambiguates quantiles", {
4343
compare_quantile_args(alist, tlist),
4444
sort(c(alist, tlist))
4545
)
46+
expect_snapshot(
47+
error = TRUE,
48+
compare_quantile_args(alist / 10, 1:9 / 10, "grf")
49+
)
50+
expect_identical(compare_quantile_args(alist, 1:9 / 10, "grf"), 1:9 / 10)
4651
alist <- c(.5, alist)
4752
expect_identical( # tlist is default, should give alist
4853
compare_quantile_args(alist, tlist),

tests/testthat/test-grf_quantiles.R

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,30 @@ test_that("quantile_rand_forest handles allows setting the trees and mtry", {
5151
expect_identical(pars$`_num_trees`, manual$`_num_trees`)
5252
})
5353

54-
test_that("quantile_rand_forest predicts reasonable quantiles", {
54+
test_that("quantile_rand_forest operates with arx_forecaster", {
5555
spec <- rand_forest(mode = "regression") %>%
56-
set_engine("grf_quantiles", quantiles = c(.2, .5, .8))
57-
expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
58-
# swapping around the probabilities, because somehow this happens in practice,
59-
# but I'm not sure how to reproduce
60-
out$fit$quantiles.orig <- c(0.5, 0.9, 0.1)
61-
expect_no_error(predict(out, tib))
56+
set_engine("grf_quantiles", quantiles = c(.1, .2, .5, .8, .9)) # non-default
57+
expect_identical(rlang::eval_tidy(spec$eng_args$quantiles), c(.1, .2, .5, .8, .9))
58+
tib <- as_epi_df(tibble(time_value = 1:25, geo_value = "ca", value = rnorm(25)))
59+
o <- arx_fcast_epi_workflow(tib, "value", trainer = spec)
60+
spec2 <- parsnip::extract_spec_parsnip(o)
61+
expect_identical(
62+
rlang::eval_tidy(spec2$eng_args$quantiles),
63+
rlang::eval_tidy(spec$eng_args$quantiles)
64+
)
65+
spec <- rand_forest(mode = "regression", "grf_quantiles")
66+
expect_null(rlang::eval_tidy(spec$eng_args))
67+
o <- arx_fcast_epi_workflow(tib, "value", trainer = spec)
68+
spec2 <- parsnip::extract_spec_parsnip(o)
69+
expect_identical(
70+
rlang::eval_tidy(spec2$eng_args$quantiles),
71+
c(.05, .1, .5, .9, .95) # merged with arx_args default
72+
)
73+
df <- epidatasets::counts_subset %>% filter(time_value >= "2021-10-01")
74+
75+
z <- arx_forecaster(df, "cases", "cases", spec2)
76+
expect_identical(
77+
nested_quantiles(z$predictions$.pred_distn[1])[[1]]$quantile_levels,
78+
c(.05, .1, .5, .9, .95)
79+
)
6280
})

0 commit comments

Comments
 (0)