Skip to content

Commit 036f6e2

Browse files
committed
Test impl epi_slide with refactor tools
1 parent a9b4ea1 commit 036f6e2

File tree

2 files changed

+188
-49
lines changed

2 files changed

+188
-49
lines changed

R/slide-refactor.R

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,73 @@ upstream_slide_to_simple_hop <- function(.f, ..., .in_colnames, .out_colnames, .
146146
)
147147
}
148148

149+
# ref_time_values_to_inp_ref_inds <- function(inp_tbl, ref_time_values) {
150+
# matches <- vec_match(ref_time_values, inp_tbl$time_value)
151+
# inp_ref_inds <- matches[!is.na(matches)]
152+
# inp_ref_inds
153+
# }
154+
155+
# complete_for_time_slide <- function(inp_tbl, inp_ref_inds, before_n_steps, after_n_steps) {
156+
# if (before_n_steps == Inf) {
157+
# # We need to get back to inp_tbl[1L,] from inp_tbl[min(inp_ref_inds),]
158+
# start_padding <- min(inp_ref_inds) - 1L
159+
# } else {
160+
# start_padding <- before_n_steps
161+
# }
162+
# end_padding <- after_n_steps
163+
# #
164+
165+
# slide_t_max <- out_t_max + after_n_steps * unit_step
166+
# slide_nrow <- time_delta_to_n_steps(slide_t_max - slide_t_min, time_type) + 1L
167+
# slide_time_values <- slide_t_min + 0L:(slide_nrow - 1L) * unit_step
168+
# slide_inp_backrefs <- vec_match(slide_time_values, inp_tbl$time_value)
169+
# }
170+
171+
ref_time_values_to_out_time_values <- function(inp_tbl, ref_time_values) {
172+
vec_set_intersect(inp_tbl$time_value, ref_time_values)
173+
}
174+
175+
slide_window <- function(inp_tbl, epikey, simple_hop, before_n_steps, after_n_steps, unit_step, time_type, out_time_values) {
176+
# TODO test whether origin time value stuff actually is helpful
177+
origin_time_value <- inp_tbl$time_value[[1L]]
178+
inp_ts <- time_minus_time_in_n_steps(inp_tbl$time_value, origin_time_value, time_type)
179+
out_ts <- time_minus_time_in_n_steps(out_time_values, origin_time_value, time_type)
180+
if (vec_size(out_ts) == 0L) {
181+
stop("FIXME TODO")
182+
} else {
183+
slide_ts <- seq(min(out_ts) - before_n_steps, max(out_ts) + after_n_steps) # TODO compare min/max vs. `[[`
184+
}
185+
slide_inp_backrefs <- vec_match(slide_ts, inp_ts)
186+
# TODO refactor to use a join if not using backrefs later anymore?
187+
#
188+
# TODO perf: try removing time_value column before slice?
189+
slide_tbl <- vec_slice(inp_tbl, slide_inp_backrefs)
190+
slide_tbl$time_value <- origin_time_value + slide_ts * unit_step
191+
192+
ref_inds <- vec_match(out_ts, slide_ts)
193+
out_tbl <- simple_hop(slide_tbl, epikey, ref_inds)
194+
out_tbl
195+
}
196+
197+
198+
199+
# # We should filter down the slide time values to ones in the input time values
200+
# # when preparing the output:
201+
# rows_should_keep1 <- !is.na(slide_inp_backrefs)
202+
# # We also need to apply the out_filter.
203+
# #
204+
# # TODO comments + test vs. just using inequality
205+
# rows_should_keep2 <- switch(out_filter_time_style,
206+
# range = vec_rep_each(
207+
# c(FALSE, TRUE, FALSE),
208+
# c(slide_start_padding_n, slide_nrow - slide_start_padding_n - after_n_steps, after_n_steps),
209+
# ),
210+
# set = vec_in(slide_time_values, out_time_values)
211+
# )
212+
# rows_should_keep <- rows_should_keep1 & rows_should_keep2
213+
# out_tbl <- vec_slice(slide, rows_should_keep)
214+
# out_tbl
215+
149216
# TODO maybe make ref_inds optional or have special handling if it's the whole sequence? But can it ever be the full sequence in the common fixed-width window case? Should be some truncation of it.
150217

151218
# TODO decide whether/where to put time range stuff

R/slide.R

Lines changed: 121 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -261,58 +261,130 @@ epi_slide <- function(
261261
# Check for duplicated time values within groups
262262
assert(check_ukey_unique(ungroup(.x), c(group_vars(.x), "time_value")))
263263

264-
# Begin handling completion. This will create a complete time index between
265-
# the smallest and largest time values in the data. This is used to ensure
266-
# that the slide function is called with a complete window of data. Each slide
267-
# group will filter this down to between its min and max time values. We also
268-
# mark which dates were in the data and which were added by our completion.
269-
date_seq_list <- full_date_seq(.x, window_args$before, window_args$after, time_type)
270-
.x$.real <- TRUE
264+
# # Begin handling completion. This will create a complete time index between
265+
# # the smallest and largest time values in the data. This is used to ensure
266+
# # that the slide function is called with a complete window of data. Each slide
267+
# # group will filter this down to between its min and max time values. We also
268+
# # mark which dates were in the data and which were added by our completion.
269+
# date_seq_list <- full_date_seq(.x, window_args$before, window_args$after, time_type)
270+
# .x$.real <- TRUE
271271

272-
# Create a wrapper that calculates and passes `.ref_time_value` to the
273-
# computation. `i` is contained in the `slide_comp_wrapper_factory`
274-
# environment such that when called within `slide_one_grp` `i` advances
275-
# through the list of reference time values within a group and then resets
276-
# back to 1 when switching groups.
277-
slide_comp_wrapper_factory <- function(kept_ref_time_values) {
278-
i <- 1L
279-
slide_comp_wrapper <- function(.x, .group_key, ...) {
280-
.ref_time_value <- kept_ref_time_values[[i]]
281-
i <<- i + 1L
282-
.slide_comp(.x, .group_key, .ref_time_value, ...)
272+
# # Create a wrapper that calculates and passes `.ref_time_value` to the
273+
# # computation. `i` is contained in the `slide_comp_wrapper_factory`
274+
# # environment such that when called within `slide_one_grp` `i` advances
275+
# # through the list of reference time values within a group and then resets
276+
# # back to 1 when switching groups.
277+
# slide_comp_wrapper_factory <- function(kept_ref_time_values) {
278+
# i <- 1L
279+
# slide_comp_wrapper <- function(.x, .group_key, ...) {
280+
# .ref_time_value <- kept_ref_time_values[[i]]
281+
# i <<- i + 1L
282+
# .slide_comp(.x, .group_key, .ref_time_value, ...)
283+
# }
284+
# slide_comp_wrapper
285+
# }
286+
287+
# # - If .x is not grouped, then the trivial group is applied:
288+
# # https://dplyr.tidyverse.org/reference/group_map.html
289+
# # - We create a lambda that forwards the necessary slide arguments to
290+
# # `epi_slide_one_group`.
291+
# # - `...` from top of `epi_slide` are forwarded to `.f` here through
292+
# # group_modify and through the lambda.
293+
# result <- group_map(
294+
# .x,
295+
# .f = function(.data_group, .group_key, ...) {
296+
# epi_slide_one_group(
297+
# .data_group, .group_key, ...,
298+
# .slide_comp_factory = slide_comp_wrapper_factory,
299+
# .before = window_args$before,
300+
# .after = window_args$after,
301+
# .ref_time_values = .ref_time_values,
302+
# .all_rows = .all_rows,
303+
# .new_col_name = .new_col_name,
304+
# .used_data_masking = used_data_masking,
305+
# .time_type = time_type,
306+
# .date_seq_list = date_seq_list
307+
# )
308+
# },
309+
# ...,
310+
# .keep = TRUE
311+
# ) %>%
312+
# list_rbind() %>%
313+
# `[`(.$.real, names(.) != ".real") %>%
314+
# arrange_col_canonical() %>%
315+
# group_by(!!!.x_orig_groups)
316+
before_n_steps <- time_delta_to_n_steps(window_args$before, time_type)
317+
after_n_steps <- time_delta_to_n_steps(window_args$after, time_type)
318+
unit_step <- unit_time_delta(time_type, format = "fast")
319+
simple_hop <- time_slide_to_simple_hop(.slide_comp = .slide_comp, ..., .before_n_steps = before_n_steps, .after_n_steps = after_n_steps)
320+
result <- .x %>%
321+
group_modify(function(grp_data, grp_key) {
322+
out_time_values <- ref_time_values_to_out_time_values(grp_data, .ref_time_values)
323+
res <- grp_data
324+
slide_values <- slide_window(grp_data, grp_key, simple_hop, before_n_steps, after_n_steps, unit_step, time_type, out_time_values)
325+
# FIXME check, de-dupe, simplify, refactor, ...
326+
if (.all_rows) {
327+
new_slide_values <- vec_cast(rep(NA, nrow(res)), slide_values)
328+
vec_slice(new_slide_values, vec_match(out_time_values, res$time_value)) <- slide_values
329+
slide_values <- new_slide_values
330+
} else {
331+
res <- vec_slice(res, vec_match(out_time_values, res$time_value))
332+
}
333+
334+
if (is.null(.new_col_name)) {
335+
if (inherits(slide_values, "data.frame")) {
336+
# Sometimes slide_values can parrot back columns already in `res`; allow
337+
# this, but balk if a column has the same name as one in `res` but a
338+
# different value:
339+
comp_nms <- names(slide_values)
340+
overlaps_existing_names <- comp_nms %in% names(res)
341+
for (comp_i in which(overlaps_existing_names)) {
342+
if (!identical(slide_values[[comp_i]], res[[comp_nms[[comp_i]]]])) {
343+
lines <- c(
344+
cli::format_error(c(
345+
"New column and old column clash",
346+
"x" = "slide computation output included a
347+
{format_varname(comp_nms[[comp_i]])} column, but `.x` already had a
348+
{format_varname(comp_nms[[comp_i]])} column with differing values",
349+
"Here are examples of differing values, where the grouping variables were
350+
{format_tibble_row(.group_key)}:"
351+
)),
352+
capture.output(print(waldo::compare(
353+
res[[comp_nms[[comp_i]]]], slide_values[[comp_i]],
354+
x_arg = rlang::expr_deparse(dplyr::expr(`$`(!!"existing", !!sym(comp_nms[[comp_i]])))), # nolint: object_usage_linter
355+
y_arg = rlang::expr_deparse(dplyr::expr(`$`(!!"comp_value", !!sym(comp_nms[[comp_i]])))) # nolint: object_usage_linter
356+
))),
357+
cli::format_message(c(
358+
">" = "You likely want to rename or remove this column from your slide
359+
computation's output, or debug why it has a different value."
360+
))
361+
)
362+
rlang::abort(paste(collapse = "\n", lines),
363+
class = "epiprocess__epi_slide_output_vs_existing_column_conflict"
364+
)
365+
}
366+
}
367+
# Unpack into separate columns (without name prefix). If there are
368+
# columns duplicating existing columns, de-dupe and order them as if they
369+
# didn't exist in slide_values.
370+
res <- dplyr::bind_cols(res, slide_values[!overlaps_existing_names])
371+
} else {
372+
# Apply default name (to vector or packed data.frame-type column):
373+
if ("slide_value" %in% names(res)) {
374+
cli_abort(c("Cannot guess a good column name for your output",
375+
"x" = "`slide_value` already exists in `.x`",
376+
">" = "Please provide a `.new_col_name`."
377+
))
378+
}
379+
res[["slide_value"]] <- slide_values
283380
}
284-
slide_comp_wrapper
381+
} else {
382+
# Vector or packed data.frame-type column (note: overlaps with existing
383+
# column names should already be forbidden by earlier validation):
384+
res[[.new_col_name]] <- slide_values
285385
}
286-
287-
# - If .x is not grouped, then the trivial group is applied:
288-
# https://dplyr.tidyverse.org/reference/group_map.html
289-
# - We create a lambda that forwards the necessary slide arguments to
290-
# `epi_slide_one_group`.
291-
# - `...` from top of `epi_slide` are forwarded to `.f` here through
292-
# group_modify and through the lambda.
293-
result <- group_map(
294-
.x,
295-
.f = function(.data_group, .group_key, ...) {
296-
epi_slide_one_group(
297-
.data_group, .group_key, ...,
298-
.slide_comp_factory = slide_comp_wrapper_factory,
299-
.before = window_args$before,
300-
.after = window_args$after,
301-
.ref_time_values = .ref_time_values,
302-
.all_rows = .all_rows,
303-
.new_col_name = .new_col_name,
304-
.used_data_masking = used_data_masking,
305-
.time_type = time_type,
306-
.date_seq_list = date_seq_list
307-
)
308-
},
309-
...,
310-
.keep = TRUE
311-
) %>%
312-
list_rbind() %>%
313-
`[`(.$.real, names(.) != ".real") %>%
314-
arrange_col_canonical() %>%
315-
group_by(!!!.x_orig_groups)
386+
res
387+
})
316388

317389
# If every group in epi_slide_one_group takes the
318390
# length(available_ref_time_values) == 0 branch then we end up here.

0 commit comments

Comments
 (0)