Skip to content

Commit 90edb46

Browse files
committed
final requests
1 parent 86c46a4 commit 90edb46

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

R/arx_classifier.R

+7-2
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,14 @@ arx_class_epi_workflow <- function(
191191
}
192192
}
193193
# regex that will match any amount of adjustment for the ahead
194-
ahead_out_name <- glue::glue("ahead_[0-9]*_{pre_out_name}")
194+
ahead_out_name_regex <- glue::glue("ahead_[0-9]*_{pre_out_name}")
195195
method_adjust_latency <- args_list$adjust_latency
196196
if (method_adjust_latency != "none") {
197+
if (method_adjust_latency != "extend_ahead") {
198+
cli_abort("only extend_ahead is currently supported",
199+
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
200+
)
201+
}
197202
r <- r %>% step_adjust_latency(!!pre_out_name,
198203
fixed_forecast_date = forecast_date,
199204
method = method_adjust_latency
@@ -204,7 +209,7 @@ arx_class_epi_workflow <- function(
204209
r <- r %>%
205210
step_mutate(
206211
across(
207-
matches(ahead_out_name),
212+
matches(ahead_out_name_regex),
208213
~ cut(.x, breaks = args_list$breaks),
209214
.names = "outcome_class",
210215
.unpack = TRUE

tests/testthat/test-snapshots.R

+20
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,24 @@ test_that("arx_classifier snapshots", {
162162
args_list = arx_class_args_list(adjust_latency = "extend_ahead", forecast_date = max_date + 2)
163163
)
164164
expect_snapshot_tibble(arc2$predictions)
165+
expect_error(
166+
arc3 <- arx_classifier(
167+
case_death_rate_subset %>%
168+
dplyr::filter(time_value >= as.Date("2021-11-01")),
169+
"death_rate",
170+
c("case_rate", "death_rate"),
171+
args_list = arx_class_args_list(adjust_latency = "extend_lags", forecast_date = max_date + 2)
172+
),
173+
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
174+
)
175+
expect_error(
176+
arc4 <- arx_classifier(
177+
case_death_rate_subset %>%
178+
dplyr::filter(time_value >= as.Date("2021-11-01")),
179+
"death_rate",
180+
c("case_rate", "death_rate"),
181+
args_list = arx_class_args_list(adjust_latency = "locf", forecast_date = max_date + 2)
182+
),
183+
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
184+
)
165185
})

0 commit comments

Comments
 (0)