Skip to content

Commit a37df95

Browse files
committed
pass all tests
1 parent 82a7674 commit a37df95

File tree

6 files changed

+211
-132
lines changed

6 files changed

+211
-132
lines changed

R/arx_forecaster.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ arx_fcast_epi_workflow <- function(
200200
} else {
201201
quantile_levels <- sort(compare_quantile_args(
202202
args_list$quantile_levels,
203-
rlang::eval_tidy(trainer$eng_args$quantiles) %||% c(.1, .5, .9),
203+
rlang::eval_tidy(trainer$eng_args$quantiles) %||%
204+
c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
204205
"grf"
205206
))
206207
trainer$eng_args$quantiles <- rlang::enquo(quantile_levels)

tests/testthat/_snaps/snapshots.md

Lines changed: 194 additions & 121 deletions
Large diffs are not rendered by default.

tests/testthat/test-arx_args_list.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ test_that("arx forecaster disambiguates quantiles", {
4141
tlist <- eval(formals(quantile_reg)$quantile_levels)
4242
expect_identical( # both default
4343
compare_quantile_args(alist, tlist),
44-
c(0.05, 0.5, 0.95)
44+
c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)
4545
)
4646
expect_snapshot(
4747
error = TRUE,

tests/testthat/test-extract_argument.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@ test_that("layer argument extractor works", {
88
expect_snapshot(error = TRUE, extract_argument(f$layers[[1]], "layer_predict", "bubble"))
99
expect_identical(
1010
extract_argument(f$layers[[2]], "layer_residual_quantiles", "quantile_levels"),
11-
c(0.0275, 0.9750)
11+
c(0.0275, 0.5, 0.9750)
1212
)
1313

1414
expect_snapshot(error = TRUE, extract_argument(f, "layer_thresh", "quantile_levels"))
1515
expect_identical(
1616
extract_argument(f, "layer_residual_quantiles", "quantile_levels"),
17-
c(0.0275, 0.9750)
17+
c(0.0275, 0.5, 0.9750)
1818
)
1919

2020
wf <- epi_workflow(postprocessor = f)
2121
expect_snapshot(error = TRUE, extract_argument(epi_workflow(), "layer_residual_quantiles", "quantile_levels"))
2222
expect_identical(
2323
extract_argument(wf, "layer_residual_quantiles", "quantile_levels"),
24-
c(0.0275, 0.9750)
24+
c(0.0275, 0.5, 0.9750)
2525
)
2626

2727
expect_snapshot(error = TRUE, extract_argument(wf, "layer_predict", c("type", "opts")))

tests/testthat/test-grf_quantiles.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ test_that("quantile_rand_forest defaults work", {
99
spec <- rand_forest(engine = "grf_quantiles", mode = "regression")
1010
expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
1111
pars <- parsnip::extract_fit_engine(out)
12-
manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(0.1, 0.5, 0.9))
12+
manual <- quantile_forest(
13+
as.matrix(tib[, 2:3]), tib$y,
14+
quantiles = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)
15+
)
1316
expect_identical(pars$quantiles.orig, manual$quantiles.orig)
1417
expect_identical(pars$`_num_trees`, manual$`_num_trees`)
1518

@@ -43,7 +46,8 @@ test_that("quantile_rand_forest handles alternative quantiles", {
4346

4447

4548
test_that("quantile_rand_forest handles allows setting the trees and mtry", {
46-
spec <- rand_forest(mode = "regression", mtry = 2, trees = 100, engine = "grf_quantiles")
49+
spec <- rand_forest(mode = "regression", mtry = 2, trees = 100) %>%
50+
set_engine(engine = "grf_quantiles", quantiles = c(0.1, 0.5, 0.9))
4751
expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
4852
pars <- parsnip::extract_fit_engine(out)
4953
manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, mtry = 2, num.trees = 100)
@@ -68,13 +72,13 @@ test_that("quantile_rand_forest operates with arx_forecaster", {
6872
spec2 <- parsnip::extract_spec_parsnip(o)
6973
expect_identical(
7074
rlang::eval_tidy(spec2$eng_args$quantiles),
71-
c(.05, .1, .5, .9, .95) # merged with arx_args default
75+
c(.05, .1, 0.25, .5, 0.75, .9, .95) # merged with arx_args default
7276
)
7377
df <- epidatasets::counts_subset %>% filter(time_value >= "2021-10-01")
7478

7579
z <- arx_forecaster(df, "cases", "cases", spec2)
7680
expect_identical(
7781
nested_quantiles(z$predictions$.pred_distn[1])[[1]]$quantile_levels,
78-
c(.05, .1, .5, .9, .95)
82+
c(.05, .1, 0.25, .5, 0.75, .9, .95)
7983
)
8084
})

tests/testthat/test-layer_threshold_preds.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ test_that("thresholds additional columns", {
5858
p <- p %>%
5959
dplyr::mutate(.quantiles = nested_quantiles(.pred_distn)) %>%
6060
tidyr::unnest(.quantiles)
61-
expect_equal(round(p$values, digits = 3), c(0.180, 0.31, 0.180, .18, 0.310, .31))
62-
expect_equal(p$quantile_levels, rep(c(.1, .9), times = 3))
61+
expect_equal(round(p$values, digits = 3),
62+
c(0.180, 0.180, 0.31, 0.180, 0.180, .18, 0.310, .31, .31))
63+
expect_equal(p$quantile_levels, rep(c(.1, 0.5, .9), times = 3))
6364
})

0 commit comments

Comments
 (0)