diff --git a/DESCRIPTION b/DESCRIPTION index 882df32c..a3a4ceaf 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -54,16 +54,20 @@ Imports: gridExtra, data.table, methods, - dplyr + dplyr, + cli, + purrr, + assertthat Suggests: lme4, httr, tibble, - testthat, + testthat (>= 3.0.0), e1071, DescTools, DSOpal, DSMolgenisArmadillo, - DSLite + DSLite, + assertthat RoxygenNote: 7.3.2 Encoding: UTF-8 diff --git a/NAMESPACE b/NAMESPACE index d737d5e6..59da737b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -104,6 +104,7 @@ export(ds.seq) export(ds.setSeed) export(ds.skewness) export(ds.sqrt) +export(ds.standardiseDf) export(ds.summary) export(ds.table) export(ds.table1D) @@ -117,6 +118,24 @@ export(ds.var) export(ds.vectorCalc) import(DSI) import(data.table) +import(dplyr) +importFrom(DSI,datashield.aggregate) +importFrom(DSI,datashield.assign) +importFrom(assertthat,assert_that) +importFrom(cli,cli_abort) +importFrom(cli,cli_alert_danger) +importFrom(cli,cli_alert_info) +importFrom(cli,cli_alert_success) +importFrom(cli,cli_alert_warning) +importFrom(cli,cli_end) +importFrom(cli,cli_li) +importFrom(cli,cli_ol) +importFrom(cli,cli_text) +importFrom(cli,cli_ul) +importFrom(purrr,map) +importFrom(purrr,map_lgl) +importFrom(purrr,pmap) +importFrom(purrr,pmap_lgl) importFrom(stats,as.formula) importFrom(stats,na.omit) importFrom(stats,ts) diff --git a/R/ds.standardiseDf.R b/R/ds.standardiseDf.R new file mode 100644 index 00000000..4ba24f90 --- /dev/null +++ b/R/ds.standardiseDf.R @@ -0,0 +1,602 @@ +#' Fill DataFrame with Missing Columns and Adjust Classes +#' +#' This function fills a given DataFrame by adding missing columns, ensuring consistent column classes, and adjusting factor levels where necessary. +#' It performs checks to detect class and factor level conflicts and prompts the user for decisions to resolve these conflicts. +#' +#' @param df.name Name of the input DataFrame to fill. +#' @param newobj Name of the new DataFrame object created after filling. +#' @param fix_class Character, determines behaviour if class of variables is not the same in all +#' studies. Option "ask" (default) provides the user with a prompt asking if they want to set the +#' class across all studies, option "no" will throw an error if class conflicts are present. +#' @param fix_levels Character, determines behaviour if levels of factor variables is not the same +#' in all studies. Option "ask" (default) provides the user with a prompt asking if they want to set +#' the levels of factor variables to be the same across all studies, whilst option "no" will throw +#' an error if factor variables do not have the same class. +#' @param datasources Data sources from which to aggregate data. Default is `NULL`. +#' @importFrom assertthat assert_that +#' @importFrom DSI datashield.aggregate datashield.assign +#' @return The filled DataFrame with added columns and adjusted classes or factor levels. +#' @export +ds.standardiseDf <- function(df.name = NULL, newobj = NULL, fix_class = "ask", fix_levels = "ask", + datasources = NULL) { + fill_warnings <- list() + + .check_arguments(df.name, newobj, fix_class, fix_levels) + + if(is.null(datasources)){ + datasources <- datashield.connections_find() + } + + col_names <- datashield.aggregate(datasources, call("colnamesDS", df.name)) + .stop_if_cols_identical(col_names) + + var_classes <- .get_var_classes(df.name, datasources) + class_conflicts <- .identify_class_conflicts(var_classes) + + datashield.assign(datasources, newobj, as.symbol(df.name)) + + if (length(class_conflicts) > 0 & fix_class == "no") { + DSI::datashield.aggregate(datasources, call("rmDS", newobj)) + cli_abort("Variables do not have the same class in all studies and `fix_class` is 'no'") + } else if (length(class_conflicts) > 0 & fix_class == "ask") { + class_decisions <- prompt_user_class_decision_all_vars( + names(class_conflicts), + var_classes$server, + dplyr::select(var_classes, all_of(names(class_conflicts))), + newobj, + datasources + ) + + withCallingHandlers({ + .fix_classes(newobj, names(class_conflicts), class_decisions, newobj, datasources) + }, warning = function(w) { + fill_warnings <<- c(fill_warnings, conditionMessage(w)) # Append warning to the list + invokeRestart("muffleWarning") # Suppress immediate display of the warning + }) + } + + unique_cols <- .get_unique_cols(col_names) + .add_missing_cols_to_df(newobj, unique_cols, newobj, datasources) + new_names <- datashield.aggregate(datasources, call("colnamesDS", newobj)) + added_cols <- .get_added_cols(col_names, new_names) + + new_classes <- .get_var_classes(newobj, datasources) + factor_vars <- .identify_factor_vars(new_classes) + factor_levels <- .get_factor_levels(factor_vars, newobj, datasources) + level_conflicts <- .identify_level_conflicts(factor_levels) + + if (length(level_conflicts) > 0 & fix_levels == "no") { + DSI::datashield.aggregate(datasources, call("rmDS", newobj)) + cli_abort("Factor variables do not have the same levels in all studies and `fix_levels` is 'no'") + } else if (length(level_conflicts) > 0 & fix_levels == "ask") { + levels_decision <- ask_question_wait_response_levels(level_conflicts, newobj, datasources) + } + + if (levels_decision == "1") { + unique_levels <- .get_unique_levels(factor_levels, level_conflicts) + .set_factor_levels(newobj, unique_levels, datasources) + } + + .print_out_messages(added_cols, class_decisions, names(class_conflicts), unique_levels, + level_conflicts, levels_decision, newobj) + + .handle_warnings(fill_warnings) + .print_class_warning(class_conflicts, fix_class, class_decisions) +} + +#' Check Function Arguments for Validity +#' +#' This function validates the arguments provided to ensure they meet specified conditions. +#' It checks that the `fix_class` and `fix_levels` arguments are set to accepted values +#' and that `df.name` and `newobj` are character strings. +#' +#' @param df.name A character string representing the name of the data frame. +#' @param newobj A character string representing the name of the new object to be created. +#' @param fix_class A character string indicating the method for handling class issues. +#' Must be either `"ask"` or `"no"`. +#' @param fix_levels A character string indicating the method for handling level issues. +#' Must be either `"ask"` or `"no"`. +#' @return NULL. This function is used for validation and does not return a value. +#' @importFrom assertthat assert_that +#' @noRd +.check_arguments <- function(df.name, newobj, fix_class, fix_levels) { + assert_that(fix_class %in% c("ask", "no")) + assert_that(fix_levels %in% c("ask", "no")) + assert_that(is.character(df.name)) + assert_that(is.character(newobj)) +} + +#' Stop If Columns Are Identical +#' +#' Checks if the columns in the data frames are identical and throws an error if they are. +#' +#' @param col_names A list of column names from different data sources. +#' @return None. Throws an error if columns are identical. +#' @importFrom cli cli_abort +#' @noRd +.stop_if_cols_identical <- function(col_names) { + are_identical <- all(sapply(col_names, identical, col_names[[1]])) + if (are_identical) { + cli_abort("Columns are identical in all data frames: nothing to fill") + } +} + +#' Get Variable Classes from DataFrame +#' +#' Retrieves the class of each variable in the specified DataFrame from different data sources. +#' +#' @param df.name Name of the input DataFrame. +#' @param datasources Data sources from which to aggregate data. +#' @return A DataFrame containing the variable classes from each data source. +#' @import dplyr +#' @noRd +.get_var_classes <- function(df.name, datasources) { + cally <- call("getClassAllColsDS", df.name) + classes <- datashield.aggregate(datasources, cally) %>% + bind_rows(.id = "server") + return(classes) +} + +#' Identify Class Conflicts +#' +#' Identifies conflicts in variable classes across different data sources. +#' +#' @param classes A DataFrame containing variable classes across data sources. +#' @return A list of variables that have class conflicts. +#' @import dplyr +#' @importFrom purrr map +#' @noRd +.identify_class_conflicts <- function(classes) { + server <- NULL + different_class <- classes |> + dplyr::select(-server) |> + map(~ unique(na.omit(.))) + + out <- different_class[which(different_class %>% map(length) > 1)] + return(out) +} + +#' Prompt User for Class Decision for All Variables +#' +#' Prompts the user to resolve class conflicts for all variables. +#' +#' @param vars A vector of variable names with class conflicts. +#' @param all_servers The names of all servers. +#' @param all_classes The classes of the variables across servers. +#' @return A vector of decisions for each variable's class. +#' @noRd +prompt_user_class_decision_all_vars <- function(vars, all_servers, all_classes, newobj, datasources) { + decisions <- c() + for (i in 1:length(vars)) { + decisions[i] <- prompt_user_class_decision(vars[i], all_servers, all_classes[[i]], newobj, datasources) + } + return(decisions) +} + +#' Prompt User for Class Decision for a Single Variable +#' +#' Prompts the user to resolve a class conflict for a single variable. +#' +#' @param var The variable name with a class conflict. +#' @param all_servers The names of all servers. +#' @param all_classes The classes of the variable across servers. +#' @importFrom cli cli_alert_warning cli_alert_danger +#' @return A decision for the variable's class. +#' @noRd +prompt_user_class_decision <- function(var, servers, classes, newobj, datasources) { + cli_alert_warning("`ds.dataFrameFill` requires that all columns have the same class.") + cli_alert_danger("Column {.strong {var}} has following classes:") + print_all_classes(servers, classes) + cli_text("") + return(ask_question_wait_response_class(var, newobj, datasources)) +} + +#' Print All Server-Class Pairs +#' +#' This function prints out a list of server names along with their corresponding +#' class types. It formats the output with a bullet-point list using the `cli` package. +#' +#' @param all_servers A character vector containing the names of servers. +#' @param all_classes A character vector containing the class types corresponding +#' to each server. +#' @return This function does not return a value. It prints the server-class pairs +#' to the console as a bulleted list. +#' @importFrom cli cli_ul cli_li cli_end +#' @noRd +print_all_classes <- function(all_servers, all_classes) { + combined <- paste(all_servers, all_classes, sep = ": ") + cli_ul() + for (i in 1:length(combined)) { + cli_li("{combined[i]}") + } + cli_end() +} + +#' Ask Question and Wait for Class Response +#' +#' Prompts the user with a question and waits for a response related to class decisions. +#' +#' @param question The question to ask the user. +#' @return The user's decision. +#' @importFrom cli cli_text cli_alert_warning cli_abort +#' @noRd +ask_question_wait_response_class <- function(var, newobj, datasources) { + readline <- NULL + ask_question_class(var) + answer <- readline() + if (answer == "6") { + DSI::datashield.aggregate(datasources, call("rmDS", newobj)) + cli_abort("Aborted `ds.dataFrameFill`", .call = NULL) + } else if (!answer %in% as.character(1:5)) { + cli_text("") + cli_alert_warning("Invalid input. Please try again.") + cli_text("") + ask_question_wait_response_class(var, newobj, datasources) + } else { + return(answer) + } +} + +#' Prompt User for Class Conversion Options +#' +#' This function prompts the user with options to convert a variable to a specific class (e.g., factor, integer, numeric, character, or logical). +#' The function provides a list of class conversion options for the specified variable and includes an option to cancel the operation. +#' +#' @param var The name of the variable for which the user is prompted to select a class conversion option. +#' +#' @importFrom cli cli_alert_info cli_ol +#' @return None. This function is used for prompting the user and does not return a value. +#' @examples +#' ask_question("variable_name") +#' @noRd +ask_question_class <- function(var) { + cli_alert_info("Would you like to:") + class_options <- c("a factor", "an integer", "numeric", "a character", "a logical vector") + class_message <- paste0("Convert `{var}` to ", class_options, " in all studies") + cli_ol( + c(class_message, "Cancel `ds.dataFrameFill` operation") + ) +} + +#' Fix Variable Classes +#' +#' Applies the user's class decisions to fix the classes of variables across different data sources. +#' +#' @param df.name The name of the DataFrame. +#' @param different_classes A list of variables with class conflicts. +#' @param class_decisions The decisions made by the user. +#' @param newobj The name of the new DataFrame. +#' @param datasources Data sources from which to aggregate data. +#' @return None. Updates the DataFrame with consistent variable classes. +#' @noRd +.fix_classes <- function(df.name, different_classes, class_decisions, newobj, datasources) { + cally <- call("fixClassDS", df.name, different_classes, class_decisions) + datashield.assign(datasources, newobj, cally) +} + +#' Get Unique Columns from Data Sources +#' +#' Retrieves all unique columns from the data sources. +#' +#' @param col_names A list of column names. +#' @return A vector of unique column names. +#' @noRd +.get_unique_cols <- function(col_names) { + return( + unique( + unlist(col_names) + ) + ) +} + +#' Add Missing Columns to DataFrame +#' +#' Adds any missing columns to the DataFrame to ensure all columns are present across data sources. +#' +#' @param df.name The name of the DataFrame. +#' @param unique_cols A vector of unique column names. +#' @param newobj The name of the new DataFrame. +#' @param datasources Data sources from which to aggregate data. +#' @return None. Updates the DataFrame with added columns. +#' @noRd +.add_missing_cols_to_df <- function(df.name, cols_to_add_if_missing, newobj, datasources) { + cally <- call("fixColsDS", df.name, cols_to_add_if_missing) + datashield.assign(datasources, newobj, cally) +} + +#' Get Added Columns +#' +#' Compares the old and new column names and identifies newly added columns. +#' +#' @param old_names A list of old column names. +#' @param new_names A list of new column names. +#' @importFrom purrr pmap +#' @return A list of added column names. +#' @noRd +.get_added_cols <- function(old_names, new_names) { + list(old_names, new_names) %>% + pmap(function(.x, .y) { + .y[!.y %in% .x] + }) +} + +#' Identify Factor Variables +#' +#' Identifies which variables are factors in the DataFrame. +#' +#' @param var_classes A DataFrame containing variable classes. +#' @return A vector of factor variables. +#' @noRd +.identify_factor_vars <- function(var_classes) { + return( + var_classes %>% + dplyr::filter(row_number() == 1) %>% + dplyr::select(where(~ . == "factor")) + ) +} + +#' Get Factor Levels from Data Sources +#' +#' Retrieves the levels of factor variables from different data sources. +#' +#' @param factor_vars A vector of factor variables. +#' @param newobj The name of the new DataFrame. +#' @param datasources Data sources from which to aggregate data. +#' @return A list of factor levels. +#' @noRd +.get_factor_levels <- function(factor_vars, df, datasources) { + cally <- call("getAllLevelsDS", df, names(factor_vars)) + return(datashield.aggregate(datasources, cally)) +} + +#' Identify Factor Level Conflicts +#' +#' Identifies conflicts in factor levels across different data sources. +#' +#' @param factor_levels A list of factor levels. +#' @return A list of variables with level conflicts. +#' @importFrom purrr map_lgl pmap_lgl +#' @noRd +.identify_level_conflicts <- function(factor_levels) { + levels <- factor_levels %>% + pmap_lgl(function(...) { + args <- list(...) + !all(map_lgl(args[-1], ~ identical(.x, args[[1]]))) + }) + + return(names(levels[levels == TRUE])) +} + +#' Ask Question and Wait for Response on Factor Levels +#' +#' Prompts the user with options for resolving factor level conflicts and waits for a response. +#' +#' @param level_conflicts A list of variables with factor level conflicts. +#' @return The user's decision. +#' @noRd +ask_question_wait_response_levels <- function(level_conflicts, newobj, datasources) { + .make_levels_message(level_conflicts) + answer <- readline() + if (answer == "3") { + DSI::datashield.aggregate(datasources, call("rmDS", newobj)) + cli_abort("Aborted `ds.dataFrameFill`", .call = NULL) + } else if (!answer %in% as.character(1:2)) { + cli_alert_warning("Invalid input. Please try again.") + cli_alert_info("") + .make_levels_message(level_conflicts) + return(ask_question_wait_response_levels(level_conflicts, newobj, datasources)) + } else { + return(answer) + } +} + +#' Make Factor Level Conflict Message +#' +#' Creates a message to alert the user about factor level conflicts and prompt for action. +#' +#' @param level_conflicts A list of variables with factor level conflicts. +#' @importFrom cli cli_alert_warning cli_alert_info cli_ol +#' @return None. Prints the message to the console. +#' @noRd +.make_levels_message <- function(level_conflicts) { + cli_alert_warning("Warning: factor variables {level_conflicts} do not have the same levels in all studies") + cli_alert_info("Would you like to:") + cli_ol(c("Create the missing levels where they are not present", "Do nothing", "Cancel `ds.dataFrameFill` operation")) +} + +#' Get Unique Factor Levels +#' +#' Retrieves the unique factor levels for variables with conflicts. +#' +#' @param factor_levels A list of factor levels. +#' @param level_conflicts A list of variables with level conflicts. +#' @importFrom purrr pmap +#' @return A list of unique factor levels. +#' @noRd +.get_unique_levels <- function(factor_levels, level_conflicts) { + unique_levels <- factor_levels %>% + map(~ .[level_conflicts]) %>% + pmap(function(...) { + as.character(c(...)) + }) %>% + map(~ unique(.)) + return(unique_levels) +} + +#' Set Factor Levels in DataFrame +#' +#' Applies the unique factor levels to the DataFrame. +#' +#' @param newobj The name of the new DataFrame. +#' @param unique_levels A list of unique factor levels. +#' @param datasources Data sources from which to aggregate data. +#' @return None. Updates the DataFrame with the new factor levels. +#' @noRd +.set_factor_levels <- function(df, unique_levels, datasources) { + cally <- call("fixLevelsDS", df, names(unique_levels), unique_levels) + datashield.assign(datasources, df, cally) +} + +#' Print Out Summary Messages +#' +#' Prints summary messages regarding the filled DataFrame, including added columns, class decisions, and factor level adjustments. +#' +#' @param added_cols A list of added columns. +#' @param class_decisions A vector of class decisions. +#' @param different_classes A list of variables with class conflicts. +#' @param unique_levels A list of unique factor levels. +#' @param level_conflicts A list of variables with level conflicts. +#' @param levels_decision The decision made regarding factor levels. +#' @param newobj The name of the new DataFrame. +#' @importFrom cli cli_text +#' @return None. Prints messages to the console. +#' @noRd +.print_out_messages <- function(added_cols, class_decisions, different_classes, unique_levels, + level_conflicts, levels_decision, newobj) { + .print_var_recode_message(added_cols, newobj) + + if (length(different_classes) > 0) { + .print_class_recode_message(class_decisions, different_classes, newobj) + cli_text("") + } + + if (length(level_conflicts) > 0 & levels_decision == "1") { + .print_levels_recode_message(unique_levels, newobj) + } +} + +#' Print Variable Recode Message +#' +#' Prints a message summarizing the columns that were added to the DataFrame. +#' +#' @param added_cols A list of added columns. +#' @param newobj The name of the new DataFrame. +#' @importFrom cli cli_text +#' @return None. Prints the message to the console. +#' @noRd +.print_var_recode_message <- function(added_cols, newobj) { + cli_alert_success("The following variables have been added to {newobj}:") + added_cols_neat <- added_cols %>% map(~ ifelse(length(.) == 0, "", .)) + var_message <- paste0(names(added_cols), " --> ", added_cols_neat) + for (i in 1:length(var_message)) { + cli_alert_info("{var_message[[i]]}") + } + cli_text("") +} + +#' Print Class Recode Message +#' +#' Prints a message summarizing the class decisions that were made for variables with conflicts. +#' +#' @param class_decisions A vector of class decisions. +#' @param different_classes A list of variables with class conflicts. +#' @param newobj The name of the new DataFrame. +#' @importFrom cli cli_alert_info cli_alert_success +#' @return None. Prints the message to the console. +#' @noRd +.print_class_recode_message <- function(class_decisions, different_classes, newobj) { + choice_neat <- .change_choice_to_string(class_decisions) + class_message <- paste0(different_classes, " --> ", choice_neat) + cli_alert_success("The following classes have been set for all datasources in {newobj}: ") + for (i in 1:length(class_message)) { + cli_alert_info("{class_message[[i]]}") + } +} + +#' Convert Class Decision Code to String +#' +#' This function converts a numeric class decision input (represented as a string) +#' into the corresponding class type string (e.g., "factor", "integer", "numeric", etc.). +#' @param class_decision A string representing the class decision. It should be +#' one of the following values: "1", "2", "3", "4", or "5". +#' @return A string representing the class type corresponding to the input: +#' "factor", "integer", "numeric", "character", or "logical". +#' @noRd +.change_choice_to_string <- function(class_decision) { + case_when( + class_decision == "1" ~ "factor", + class_decision == "2" ~ "integer", + class_decision == "3" ~ "numeric", + class_decision == "4" ~ "character", + class_decision == "5" ~ "logical" + ) +} + +#' Print Factor Levels Recode Message +#' +#' Prints a message summarizing the factor level decisions that were made for variables with conflicts. +#' +#' @param unique_levels A list of unique factor levels. +#' @param newobj The name of the new DataFrame. +#' @importFrom cli cli_alert_success cli_alert_info +#' @return None. Prints the message to the console. +#' @noRd +.print_levels_recode_message <- function(unique_levels, newobj) { + levels_message <- .make_levels_recode_message(unique_levels) + cli_alert_success("The following levels have been set for all datasources in {newobj}: ") + for (i in 1:length(levels_message)) { + cli_alert_info("{levels_message[[i]]}") + } +} + +#' Make Levels Recode Message +#' +#' Creates a message to alert the user about factor level recoding. +#' +#' @param unique_levels A list of unique factor levels. +#' @return A formatted string summarizing the level recoding. +#' @importFrom purrr pmap +#' @noRd +.make_levels_recode_message <- function(unique_levels) { + return( + list(names(unique_levels), unique_levels) %>% + pmap(function(.x, .y) { + paste0(.x, " --> ", paste0(.y, collapse = ", ")) + }) + ) +} + +#' Handle Warnings for Class Conversion Issues +#' +#' This function iterates through a list of warnings generated during class conversion and +#' triggers a danger alert if any warnings indicate that the conversion has resulted in `NA` values. +#' +#' @param fill_warnings A list or vector of warning messages generated during class conversion. +#' If any warnings indicate that `NA` values were introduced, a danger alert will be displayed. +#' @return NULL. This function is used for its side effects of printing alerts. +#' @importFrom cli cli_alert_danger +#' @noRd +.handle_warnings <- function(fill_warnings) { + if (length(fill_warnings) > 0) { + for (i in seq_along(fill_warnings)) { + if (grepl("NAs introduced by coercion", fill_warnings[[i]])) { + cli_alert_danger("Class conversion resulted in the creation of NA values.") + } else { + cli_alert_danger(fill_warnings[[i]]) + } + } + } +} + +#' Print Warning for Class Conflicts in Data Conversion +#' +#' This function displays a warning when there are class conflicts in a dataset that may have resulted +#' from incompatible class changes during data conversion. It alerts users to verify column classes, +#' as incompatible changes could corrupt the data. +#' +#' @param class_conflicts A list or vector of conflicting classes identified during conversion. +#' @param fix_class A string indicating the user's choice for fixing class conflicts. Typically, +#' this is "ask" if the user is prompted to confirm class changes. +#' @param class_decisions A vector of decisions made for class conversions. When any value is not +#' "6", it indicates unresolved class conflicts. +#' @return NULL. This function is used for its side effects of printing alerts. +#' @importFrom cli cli_alert_warning +#' @noRd +.print_class_warning <- function(class_conflicts, fix_class, class_decisions) { + if(length(class_conflicts) > 0 & fix_class == "ask" & all(!class_decisions == "6")) { + cli_alert_warning("Please check all columns that have changed class. Not all class changes + are compatible with all data types, so this could have corrupted the data.") + } +} + +readline <- NULL diff --git a/man/ds.standardiseDf.Rd b/man/ds.standardiseDf.Rd new file mode 100644 index 00000000..4f544f1d --- /dev/null +++ b/man/ds.standardiseDf.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ds.standardiseDf.R +\name{ds.standardiseDf} +\alias{ds.standardiseDf} +\title{Fill DataFrame with Missing Columns and Adjust Classes} +\usage{ +ds.standardiseDf( + df.name = NULL, + newobj = NULL, + fix_class = "ask", + fix_levels = "ask", + datasources = NULL +) +} +\arguments{ +\item{df.name}{Name of the input DataFrame to fill.} + +\item{newobj}{Name of the new DataFrame object created after filling.} + +\item{fix_class}{Character, determines behaviour if class of variables is not the same in all +studies. Option "ask" (default) provides the user with a prompt asking if they want to set the +class across all studies, option "no" will throw an error if class conflicts are present.} + +\item{fix_levels}{Character, determines behaviour if levels of factor variables is not the same +in all studies. Option "ask" (default) provides the user with a prompt asking if they want to set +the levels of factor variables to be the same across all studies, whilst option "no" will throw +an error if factor variables do not have the same class.} + +\item{datasources}{Data sources from which to aggregate data. Default is `NULL`.} +} +\value{ +The filled DataFrame with added columns and adjusted classes or factor levels. +} +\description{ +This function fills a given DataFrame by adding missing columns, ensuring consistent column classes, and adjusting factor levels where necessary. +It performs checks to detect class and factor level conflicts and prompts the user for decisions to resolve these conflicts. +} diff --git a/tests/testthat/_snaps/smk-standardiseDf.md b/tests/testthat/_snaps/smk-standardiseDf.md new file mode 100644 index 00000000..ac1558bd --- /dev/null +++ b/tests/testthat/_snaps/smk-standardiseDf.md @@ -0,0 +1,86 @@ +# ask_question displays the correct prompt + + Code + ask_question_class("my_var") + Message + i Would you like to: + 1. Convert `my_var` to a factor in all studies + 2. Convert `my_var` to an integer in all studies + 3. Convert `my_var` to numeric in all studies + 4. Convert `my_var` to a character in all studies + 5. Convert `my_var` to a logical vector in all studies + 6. Cancel `ds.dataFrameFill` operation + +# print_all_classes prints the correct message + + Code + print_all_classes(c("server_1", "server_2", "server_3"), c("numeric", "factor", + "integer")) + Message + * server_1: numeric + * server_2: factor + * server_3: integer + +# .make_levels_message makes correct message + + Code + .make_levels_message(level_conflicts) + Message + ! Warning: factor variables fac_col2, fac_col3, fac_col6, and fac_col9 do not have the same levels in all studies + i Would you like to: + 1. Create the missing levels where they are not present + 2. Do nothing + 3. Cancel `ds.dataFrameFill` operation + +# .print_var_recode_message prints the correct message + + Code + .print_var_recode_message(added_cols, "test_df") + Message + v The following variables have been added to test_df: + i server_1 --> col11 + i server_2 --> col11 + i server_3 --> col12 + + +# .print_class_recode_message prints the correct message + + Code + .print_class_recode_message(class_decisions, different_classes, "test_df") + Message + v The following classes have been set for all datasources in test_df: + i fac_col4 --> factor + i fac_col5 --> logical + +# .print_levels_recode_message prints the correct message + + Code + .print_levels_recode_message(unique_levs, "test_df") + Message + v The following levels have been set for all datasources in test_df: + i fac_col2 --> Blue, Green, Red + i fac_col3 --> No, Yes + i fac_col6 --> Bird, Cat, Dog + i fac_col9 --> False, True + +# .print_out_messages prints the correct messages + + Code + .print_out_messages(added_cols, class_decisions, different_classes, unique_levs, + level_conflicts, "1", "test_df") + Message + v The following variables have been added to test_df: + i server_1 --> col11 + i server_2 --> col11 + i server_3 --> col12 + + v The following classes have been set for all datasources in test_df: + i fac_col4 --> factor + i fac_col5 --> logical + + v The following levels have been set for all datasources in test_df: + i fac_col2 --> Blue, Green, Red + i fac_col3 --> No, Yes + i fac_col6 --> Bird, Cat, Dog + i fac_col9 --> False, True + diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R new file mode 100644 index 00000000..919df923 --- /dev/null +++ b/tests/testthat/helpers.R @@ -0,0 +1,218 @@ +#' Create a DSLite login object that can be used for testing +#' +#' @param assign_method A string specifying the name of the custom assign method to be added +#' to the DSLite server. If `NULL`, no additional assign method is added. Default is `NULL`. +#' @param aggregate_method A string specifying the name of the custom aggregate method to be +#' added to the DSLite server. If `NULL`, no additional aggregate method is added. Default is `NULL`. +#' @param tables A named list of tables to be made available on the DSLite server. Default is `NULL`. +#' +#' @return A DataSHIELD login object containing the necessary connection information for the DSLite server. +#' +#' @examples +#' \dontrun{ +#' # Prepare a DSLite server with default methods and custom assign/aggregate methods +#' login_data <- .prepare_dslite( +#' assign_method = "customAssign", +#' aggregate_method = "customAggregate", +#' tables = list(mtcars = mtcars, mtcars_group = mtcars_group) +#' ) +#' +#' @importFrom DSLite newDSLiteServer +#' @importFrom DSI newDSLoginBuilder +#' @export +.prepare_dslite <- function(assign_method = NULL, aggregate_method = NULL, tables = NULL) { + + options(datashield.env = environment()) + dslite.server <- DSLite::newDSLiteServer(tables = tables) + dslite.server$config(defaultDSConfiguration(include = c("dsBase", "dsTidyverse"))) + dslite.server$aggregateMethod("exists", "base::exists") + dslite.server$aggregateMethod("classDS", "dsBase::classDS") + dslite.server$aggregateMethod("lsDS", "dsBase::lsDS") + dslite.server$aggregateMethod("dsListDisclosureSettings", "dsTidyverse::dsListDisclosureSettings") + + if (!is.null(assign_method)) { + dslite.server$assignMethod(assign_method, paste0("dsTidyverse::", assign_method)) + } + + if (!is.null(aggregate_method)) { + dslite.server$aggregateMethod(assign_method, paste0("dsTidyverse::", assign_method)) + } + + builder <- DSI::newDSLoginBuilder() + builder$append(server = "server_1", url = "dslite.server", driver = "DSLiteDriver") + builder$append(server = "server_2", url = "dslite.server", driver = "DSLiteDriver") + builder$append(server = "server_3", url = "dslite.server", driver = "DSLiteDriver") + login_data <- builder$build() + return(login_data) +} + +#' Create a mixed dataframe with factor and other types of columns +#' +#' This function generates a dataframe with a specified number of rows, +#' factor columns, and other columns (integer, numeric, and string). +#' +#' @param n_rows Number of rows in the dataframe. Default is 10,000. +#' @param n_factor_cols Number of factor columns in the dataframe. Default is 15. +#' @param n_other_cols Number of other columns (integer, numeric, and string) in the dataframe. Default is 15. +#' +#' @return A dataframe with the specified number of rows and columns, containing mixed data types. +#' @importFrom dplyr bind_cols +#' @importFrom purrr map_dfc +#' @examples +#' df <- create_mixed_dataframe(n_rows = 100, n_factor_cols = 10, n_other_cols = 5) +create_mixed_dataframe <- function(n_rows = 10000, n_factor_cols = 15, n_other_cols = 15) { + + # Function to create a factor column with defined levels + create_factor_column <- function(levels, n = n_rows) { + set.seed(123) # Set seed before sample for reproducibility + factor(sample(levels, n, replace = TRUE)) + } + + # Define factor levels for different columns + factor_levels <- list( + c("Low", "Medium", "High"), + c("Red", "Green", "Blue"), + c("Yes", "No"), + c("A", "B", "C"), + c("One", "Two", "Three"), + c("Cat", "Dog", "Bird"), + c("Small", "Medium", "Large"), + c("Alpha", "Beta", "Gamma"), + c("True", "False"), + c("Left", "Right"), + c("North", "South", "East", "West"), + c("Day", "Night"), + c("Up", "Down"), + c("Male", "Female"), + c("Summer", "Winter", "Spring", "Fall") + ) + + # Create factor columns + factor_columns <- map_dfc(factor_levels[1:n_factor_cols], create_factor_column) + colnames(factor_columns) <- paste0("fac_col", 1:n_factor_cols) + + # Function to create other types of columns + create_other_column <- function(type, n = n_rows) { + set.seed(123) # Set seed before sample for reproducibility + switch(type, + "int" = sample(1:100, n, replace = TRUE), # Integer column + "num" = runif(n, 0, 100), # Numeric column + "str" = sample(letters, n, replace = TRUE), # Character column + "log" = sample(c(TRUE, FALSE), n, replace = TRUE) # Logical column + ) + } + + # Ensure that each data type is included + column_types <- c( + "int", "num", "str", "log", "int", + "num", "str", "log", "int", "num", + "str", "int", "num", "log", "str" + ) + + # Create other columns with specified types + other_columns <- map_dfc(column_types[1:n_other_cols], create_other_column) + colnames(other_columns) <- paste0("col", (n_factor_cols + 1):(n_factor_cols + n_other_cols)) + + # Combine factor and other columns into a single dataframe + df <- bind_cols(factor_columns, other_columns) + + return(df) +} + + +#' Modify factor levels for partial overlap +#' +#' This function takes two sets of factor levels, computes the common and unique levels, +#' and returns a new set of levels with partial overlap. +#' +#' @param levels1 First set of factor levels. +#' @param levels2 Second set of factor levels. +#' +#' @return A character vector of new factor levels with partial overlap. +#' @examples +#' new_levels <- partial_overlap_levels(c("A", "B", "C"), c("B", "C", "D")) +partial_overlap_levels <- function(levels1, levels2) { + common <- intersect(levels1, levels2) + unique1 <- setdiff(levels1, common) + unique2 <- setdiff(levels2, common) + + # Set seed before each sample call + set.seed(123) + sampled_unique1 <- sample(unique1, length(unique1) * 0.5) + + set.seed(123) + sampled_unique2 <- sample(unique2, length(unique2) * 0.5) + + new_levels <- c(common, sampled_unique1, sampled_unique2) + return(new_levels) +} + + +#' Create additional dataframes with specific conditions +#' +#' This function generates additional dataframes based on an input dataframe, modifying column classes and levels, +#' and adding new columns with unique names. Different seeds are used for each iteration of the loop, +#' ensuring reproducibility of the generated dataframes. +#' +#' @param base_df The base dataframe used to create the additional dataframes. +#' @param n_rows Number of rows in the additional dataframes. Default is 10,000. +#' @param df_names Names of the additional dataframes to be created. Default is c("df1", "df2", "df3"). +#' +#' @return A list of dataframes with the specified modifications. +#' @importFrom dplyr bind_cols +#' @examples +#' base_df <- create_mixed_dataframe(n_rows = 100, n_factor_cols = 10, n_other_cols = 5) +#' additional_dfs <- create_additional_dataframes(base_df, n_rows = 1000, df_names = c("df1", "df2")) +create_additional_dataframes <- function(base_df, n_rows = 10000, df_names = c("df1", "df2", "df3")) { + + # Define a fixed sequence of seeds, one for each dataframe to be created + seeds <- c(123, 456, 789, 101112) + + df_list <- list() + + for (i in seq_along(df_names)) { + # Set the seed for this iteration based on the pre-defined seeds + set.seed(seeds[i]) + + # Proceed with the dataframe generation process + overlap_cols <- sample(colnames(base_df), size = round(0.8 * ncol(base_df))) + df <- base_df + cols_to_modify_class <- sample(overlap_cols, size = round(0.2 * length(overlap_cols))) + + # Modify columns to have different data types + for (col in cols_to_modify_class) { + current_class <- class(df[[col]]) + new_class <- switch(current_class, + "factor" = as.character(df[[col]]), + "character" = as.factor(df[[col]]), + "numeric" = as.integer(df[[col]]), + "integer" = as.numeric(df[[col]]), + df[[col]]) + df[[col]] <- new_class + } + + # Modify factor levels for partial overlap + factor_cols <- colnames(base_df)[sapply(base_df, is.factor)] + overlap_factor_cols <- intersect(overlap_cols, factor_cols) + cols_to_modify_levels <- sample(overlap_factor_cols, size = round(0.5 * length(overlap_factor_cols))) + + for (col in cols_to_modify_levels) { + original_levels <- levels(base_df[[col]]) + new_levels <- partial_overlap_levels(original_levels, original_levels) + df[[col]] <- factor(df[[col]], levels = new_levels) + } + + # Create new random columns for each dataframe (these will vary by seed) + set.seed(seeds[i]) # Set the seed again for generating new columns + n_new_cols <- round(0.2 * ncol(base_df)) + new_col_names <- paste0(df_names[i], "_new_col_", 1:n_new_cols) + new_cols <- data.frame(matrix(runif(n_rows * n_new_cols), ncol = n_new_cols)) + colnames(new_cols) <- new_col_names + + # Bind new columns to the dataframe + df <- bind_cols(df, new_cols) + df_list[[df_names[i]]] <- df + } + + return(df_list) +} diff --git a/tests/testthat/test-smk-standardiseDf.R b/tests/testthat/test-smk-standardiseDf.R new file mode 100644 index 00000000..fca136ac --- /dev/null +++ b/tests/testthat/test-smk-standardiseDf.R @@ -0,0 +1,746 @@ +suppressWarnings(library(DSLite)) +library(purrr) +library(dplyr) +# devtools::install_github("datashield/dsBase", ref = "v6.4.0-dev") +library(dsBase) +library(dsBaseClient) +library(purrr) +library(dsTidyverse) +source("~/Library/Mobile Documents/com~apple~CloudDocs/work/repos/dsBaseClient/tests/testthat/helpers.R") +# devtools::load_all("~/Library/Mobile Documents/com~apple~CloudDocs/work/repos/dsTidyverse") +options("datashield.return_errors" = TRUE) +testthat::local_edition(3) + +df <- create_mixed_dataframe(n_rows = 100, n_factor_cols = 10, n_other_cols = 10) + +df_1 <- df %>% select(1:5, 6, 9, 12, 15, 18) %>% + mutate( + fac_col2 = factor(fac_col2, levels = c("Blue", "Green")), + fac_col4 = as.numeric(fac_col4), + fac_col5 = as.logical(fac_col5)) + +df_2 <- df %>% select(1:5, 7, 10, 13, 16, 19) %>% + mutate( + fac_col2 = factor(fac_col2, levels = c("Green", "Red")), + fac_col3 = factor(fac_col3, levels = "No"), + fac_col4 = as.character(fac_col4), + fac_col5 = as.integer(fac_col5)) + +df_3 <- df %>% select(1:5, 11, 14, 17, 20) %>% + mutate( + fac_col2 = factor(fac_col2, levels = "Blue"), + fac_col3 = factor(fac_col3, levels = "Yes")) + +options(datashield.env = environment()) + +dslite.server <- newDSLiteServer( + tables = list( + df_1 = df_1, + df_2 = df_2, + df_3 = df_3 + ) +) + +dslite.server$config(defaultDSConfiguration(include = c("dsBase", "dsTidyverse", "dsDanger"))) +dslite.server$aggregateMethod("getClassAllColsDS", "getClassAllColsDS") +dslite.server$assignMethod("fixClassDS", "fixClassDS") +dslite.server$assignMethod("fixColsDS", "fixColsDS") +dslite.server$aggregateMethod("getAllLevelsDS", "getAllLevelsDS") +dslite.server$assignMethod("fixLevelsDS", "fixLevelsDS") + +builder <- DSI::newDSLoginBuilder() + +builder$append( + server = "server_1", + url = "dslite.server", + driver = "DSLiteDriver" +) + +builder$append( + server = "server_2", + url = "dslite.server", + driver = "DSLiteDriver" +) + +builder$append( + server = "server_3", + url = "dslite.server", + driver = "DSLiteDriver" +) + +logindata <- builder$build() +conns <- DSI::datashield.login(logins = logindata, assign = FALSE) + +datashield.assign.table(conns["server_1"], "df", "df_1") +datashield.assign.table(conns["server_2"], "df", "df_2") +datashield.assign.table(conns["server_3"], "df", "df_3") + +datashield.assign.table(conns["server_1"], "df_ident", "df_1") +datashield.assign.table(conns["server_2"], "df_ident", "df_1") +datashield.assign.table(conns["server_3"], "df_ident", "df_1") + +#################################################################################################### +# Code that will be used in multiple tests +#################################################################################################### +var_class <- .get_var_classes("df", datasources = conns) + +class_conflicts <- .identify_class_conflicts(var_class) + +different_classes <- c("fac_col4", "fac_col5") + +class_decisions <- c("1", "5") + +.fix_classes( + df.name = "df", + different_classes = different_classes, + class_decisions = class_decisions, + newobj = "new_classes", + datasources = conns) + +cols_to_set <- c( + "fac_col1", "fac_col2", "fac_col3", "fac_col4", "fac_col5", "fac_col6", "fac_col9", "col12", + "col15", "col18", "fac_col7", "fac_col10", "col13", "col16", "col19", "col11", "col14", "col17", + "col20") + +.add_missing_cols_to_df( + df.name = "df", + cols_to_add_if_missing = cols_to_set, + newobj = "with_new_cols", + datasources = conns) + +old_cols <- ds.colnames("df") + +new_cols <- c("col11", "col12", "col13", "col14", "col15", "col16", "col17", "col18", "col19", + "col20", "fac_col1", "fac_col10", "fac_col2", "fac_col3", "fac_col4", "fac_col5", + "fac_col6", "fac_col7", "fac_col9") + +new_cols_servers <- list( + server_1 = new_cols, + server_2 = new_cols, + server_3 = new_cols +) + +added_cols <- .get_added_cols(old_cols, new_cols_servers) + +var_class_fact <- .get_var_classes("with_new_cols", datasources = conns) + +fac_vars <- .identify_factor_vars(var_class_fact) + +fac_levels <- .get_factor_levels(fac_vars, "with_new_cols", conns) + +level_conflicts <- .identify_level_conflicts(fac_levels) + +unique_levs <- .get_unique_levels(fac_levels, level_conflicts) + +#################################################################################################### +# Tests +#################################################################################################### +test_that(".stop_if_cols_identical throws error if columns are identical", { + + identical_cols <- list( + c("col1", "col2", "col3"), + c("col1", "col2", "col3"), + c("col1", "col2", "col3") + ) + + expect_error( + .stop_if_cols_identical(identical_cols), + "Columns are identical in all data frames: nothing to fill" + ) + +}) + +test_that(".stop_if_cols_identical doesn't throw error if data frames have different columns", { + + different_cols <- list( + c("col1", "col2", "col3"), + c("col1", "col2", "col4"), + c("col1", "col5", "col3") + ) + + expect_silent( + .stop_if_cols_identical(different_cols) + ) + +}) + +test_that(".get_var_classes returns correct output", { + + expected <- tibble( + server = c("server_1", "server_2", "server_3"), + fac_col1 = c("factor", "factor", "factor"), + fac_col2 = c("factor", "factor", "factor"), + fac_col3 = c("factor", "factor", "factor"), + fac_col4 = c("numeric", "character", "factor"), + fac_col5 = c("logical", "integer", "factor"), + fac_col6 = c("factor", NA, NA), + fac_col9 = c("factor", NA, NA), + col12 = c("numeric", NA, NA), + col15 = c("integer", NA, NA), + col18 = c("logical", NA, NA), + fac_col7 = c(NA, "factor", NA), + fac_col10 = c(NA, "factor", NA), + col13 = c(NA, "character", NA), + col16 = c(NA, "numeric", NA), + col19 = c(NA, "integer", NA), + col11 = c(NA, NA, "integer"), + col14 = c(NA, NA, "logical"), + col17 = c(NA, NA, "character"), + col20 = c(NA, NA, "numeric") + ) + + expect_equal(var_class, expected) + +}) + +test_that(".identify_class_conflicts returns correct output", { + expected <- list( + fac_col4 = c("numeric", "character", "factor"), + fac_col5 = c("logical", "integer", "factor") + ) + + expect_equal(class_conflicts, expected) + +}) + +test_that("ask_question displays the correct prompt", { + expect_snapshot(ask_question_class("my_var")) +}) + +test_that("ask_question_wait_response_class continues with valid response", { + expect_equal( + with_mocked_bindings( + ask_question_wait_response_class("a variable"), + ask_question_class = function(var) "A question", + readline = function() "1" + ), "1" + ) +}) + +test_that("ask_question_wait_response_class throws error if option 6 selected", { + expect_error( + with_mocked_bindings( + ask_question_wait_response_class("a variable"), + ask_question_class = function(var) "A question", + readline = function() "6") + ) +}) + +test_that("print_all_classes prints the correct message", { + expect_snapshot( + print_all_classes( + c("server_1", "server_2", "server_3"), + c("numeric", "factor", "integer") + ) + ) +}) + +test_that("prompt_user_class_decision function properly", { + expect_message( + with_mocked_bindings( + prompt_user_class_decision( + var = "test_col", + servers = c("server_1", "server_2", "server_3"), + classes = c("numeric", "character", "factor"), + newobj = "test_df", + datasources = datasources), + ask_question_wait_response_class = function(var, newobj, datasources) "test_col" + ) + ) + + expect_equal( + with_mocked_bindings( + prompt_user_class_decision( + var = "test_col", + servers = c("server_1", "server_2", "server_3"), + classes = c("numeric", "character", "factor"), + newobj = "test_df", + datasources = datasources), + ask_question_wait_response_class = function(var, newobj, datasources) "test_col" + ), + "test_col" + ) +}) + +test_that("prompt_user_class_decision_all_vars returns correct value", { + expect_equal( + with_mocked_bindings( + prompt_user_class_decision_all_vars( + vars = c("test_var_1", "test_var_2"), + all_servers = c("server_1", "server_2", "server_3"), + all_classes = tibble( + test_var_1 = c("numeric", "character", "factor"), + test_var_2 = c("logical", "integer", "factor") + ), + "test_df", + conns), + prompt_user_class_decision = function(var, server, classes, newobj, datasources) "1" + ), + c("1", "1") + ) +}) + +test_that(".fix_classes sets the correct classes in serverside data frame", { + + expect_equal( + unname(unlist(ds.class("df$fac_col4"))), + c("numeric", "character", "factor") + ) + + expect_equal( + unname(unlist(ds.class("df$fac_col5"))), + c("logical", "integer", "factor") + ) + + expect_equal( + unname(unlist(ds.class("new_classes$fac_col4"))), + rep("factor", 3) + ) + + expect_equal( + unname(unlist(ds.class("new_classes$fac_col5"))), + rep("logical", 3) + ) + +}) + +test_that(".get_unique_cols extracts unique names from a list", { + expect_equal( + .get_unique_cols( + list( + server_1 = c("col_1", "col_2", "col_3"), + server_1 = c("col_1", "col_2", "col_4"), + server_1 = c("col_2", "col_3", "col_3", "col_5") + ) + ), + c("col_1", "col_2", "col_3", "col_4", "col_5") + ) +}) + +test_that(".add_missing_cols_to_df correctly creates missing columns", { + + new_cols <- c("col11", "col12", "col13", "col14", "col15", "col16", "col17", "col18", "col19", + "col20", "fac_col1", "fac_col10", "fac_col2", "fac_col3", "fac_col4", "fac_col5", + "fac_col6", "fac_col7", "fac_col9") + + observed <- ds.colnames("with_new_cols") + + expected <- list( + server_1 = new_cols, + server_2 = new_cols, + server_3 = new_cols + ) + + expect_equal(observed, expected) +}) + +test_that(".get_added_cols correctly identifies newly added columns", { + + expect_equal( + added_cols, + list( + server_1 = c("col11", "col13", "col14", "col16", "col17", "col19", "col20", "fac_col10", "fac_col7"), + server_2 = c("col11", "col12", "col14", "col15", "col17", "col18", "col20", "fac_col6", "fac_col9"), + server_3 = c("col12", "col13", "col15", "col16", "col18", "col19", "fac_col10", "fac_col6", "fac_col7", "fac_col9") + ) + ) +}) + +test_that(".identify_factor_vars correctly identifies factor variables", { + + + + var_class_fact <- var_class |> dplyr::select(server: col18) + expect_equal( + names(fac_vars), + c("fac_col1", "fac_col2", "fac_col3", "fac_col6", "fac_col9") + ) +}) + +test_that(".get_factor_levels correctly identifies factor levels", { + expected <- list( + server_1 = list( + fac_col1 = c("High", "Low", "Medium"), + fac_col2 = c("Blue", "Green"), + fac_col3 = c("No", "Yes"), + fac_col6 = c("Bird", "Cat", "Dog"), + fac_col9 = c("False", "True") + ), + server_2 = list( + fac_col1 = c("High", "Low", "Medium"), + fac_col2 = c("Green", "Red"), + fac_col3 = c("No"), + fac_col6 = NULL, + fac_col9 = NULL + ), + server_3 = list( + fac_col1 = c("High", "Low", "Medium"), + fac_col2 = c("Blue"), + fac_col3 = c("Yes"), + fac_col6 = NULL, + fac_col9 = NULL + ) + ) + + expect_equal(fac_levels, expected) +}) + +test_that(".identify_level_conflicts correctly factor columns with different levels", { + expect_equal( + .identify_level_conflicts(fac_levels), + c("fac_col2", "fac_col3", "fac_col6", "fac_col9") + ) + +}) + +test_that("ask_question_wait_response_levels continues with valid response", { + expect_equal( + with_mocked_bindings( + suppressWarnings(ask_question_wait_response_levels("test variable", "test_obj", conns)), + readline = function() "1" + ), "1" + ) + + expect_equal( + with_mocked_bindings( + suppressWarnings(ask_question_wait_response_levels("test variable", "test_obj", conns)), + readline = function() "1" + ), "1" + ) + +}) + +test_that("ask_question_wait_response_levels aborts with response of 3", { + expect_error( + with_mocked_bindings( + suppressWarnings(ask_question_wait_response_levels("test variable", "test_obj", conns)), + readline = function() "3") + ) +}) + +test_that(".make_levels_message makes correct message", { + expect_snapshot(.make_levels_message(level_conflicts)) +}) + +test_that(".get_unique_levels extracts all possible levels", { + + expected <- list( + fac_col2 = c("Blue", "Green", "Red"), + fac_col3 = c("No", "Yes"), + fac_col6 = c("Bird", "Cat", "Dog"), + fac_col9 = c("False", "True") + ) + + expect_equal(unique_levs, expected) + +}) + +test_that(".set_factor_levels sets levels correctly", { + .set_factor_levels("with_new_cols", unique_levs, conns) + + expect_equal( + ds.levels("with_new_cols$fac_col2") |> map(~.x[[1]]), + list( + server_1 = c("Blue", "Green", "Red"), + server_2 = c("Blue", "Green", "Red"), + server_3 = c("Blue", "Green", "Red") + ) + ) + + expect_equal( + ds.levels("with_new_cols$fac_col3") |> map(~.x[[1]]), + list( + server_1 = c("No", "Yes"), + server_2 = c("No", "Yes"), + server_3 = c("No", "Yes") + ) + ) + + expect_equal( + ds.levels("with_new_cols$fac_col6") |> map(~.x[[1]]), + list( + server_1 = c("Bird", "Cat", "Dog"), + server_2 = c("Bird", "Cat", "Dog"), + server_3 = c("Bird", "Cat", "Dog") + ) + ) + + expect_equal( + ds.levels("with_new_cols$fac_col9") |> map(~.x[[1]]), + list( + server_1 = c("False", "True"), + server_2 = c("False", "True"), + server_3 = c("False", "True") + ) + ) + +}) + +test_that(".print_var_recode_message prints the correct message", { + expect_snapshot(.print_var_recode_message(added_cols, "test_df")) +}) + +test_that(".print_class_recode_message prints the correct message", { + expect_snapshot( + .print_class_recode_message(class_decisions, different_classes, "test_df") + ) +}) + +test_that(".print_levels_recode_message prints the correct message", { + expect_snapshot( + .print_levels_recode_message(unique_levs, "test_df") + ) +}) + +test_that(".make_levels_recode_message prints the correct message", { + expect_equal( + .make_levels_recode_message(unique_levs), + list( + "fac_col2 --> Blue, Green, Red", + "fac_col3 --> No, Yes", + "fac_col6 --> Bird, Cat, Dog", + "fac_col9 --> False, True" + ) + ) +}) + +test_that(".print_out_messages prints the correct messages", { + expect_snapshot( + .print_out_messages( + added_cols, class_decisions, different_classes, unique_levs, level_conflicts, "1", "test_df" + ) + ) +}) + +test_that(".change_choice_to_string converts numeric class codes to strings correctly", { + expect_equal(.change_choice_to_string("1"), "factor") + expect_equal(.change_choice_to_string("2"), "integer") + expect_equal(.change_choice_to_string("3"), "numeric") + expect_equal(.change_choice_to_string("4"), "character") + expect_equal(.change_choice_to_string("5"), "logical") +}) + +test_that("ds.standardiseDf doesn't run if dataframes are identical", { + expect_error( + ds.standardiseDf( + df = "df_ident", + newobj = "test_fill" + ), + "Columns are identical" + ) +}) + +test_that("ds.standardiseDf works when called directly and class conversion is factor", { + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "test_fill" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) "1", + ask_question_wait_response_levels = function(levels_conflict, newobj, datasources) "2" + ) + + expect_equal( + ds.class("test_fill$fac_col4")[[1]], + "factor" + ) +}) + +test_that("ds.standardiseDf returns warning when called directly and class conversion is integer", { + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "test_fill" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) c("2", "2"), + ask_question_wait_response_levels = function(levels_conflict, newobj, datasources) "2" + ) + + expect_equal( + ds.class("test_fill$fac_col4")[[1]], + "integer" + ) + + expect_equal( + ds.class("test_fill$fac_col5")[[1]], + "integer" + ) +}) + +test_that("ds.standardiseDf returns warning when called directly and class conversion is numeric", { + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "test_fill" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) c("3", "3"), + ask_question_wait_response_levels = function(levels_conflict, newobj, datasources) "2" + ) + + expect_equal( + ds.class("test_fill$fac_col4")[[1]], + "numeric" + ) + + expect_equal( + ds.class("test_fill$fac_col5")[[1]], + "numeric" + ) +}) + +test_that("ds.standardiseDf returns warning when called directly and class conversion is character", { + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "test_fill" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) c("4", "4"), + ask_question_wait_response_levels = function(levels_conflict, newobj, datasources) "2" + ) + + expect_equal( + ds.class("test_fill$fac_col4")[[1]], + "character" + ) + + expect_equal( + ds.class("test_fill$fac_col5")[[1]], + "character" + ) +}) + +test_that("ds.standardiseDf returns warning when called directly and class conversion is logical", { + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "test_fill" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) c("5", "5"), + ask_question_wait_response_levels = function(levels_conflict, newobj, datasources) "2" + ) + + expect_equal( + ds.class("test_fill$fac_col4")[[1]], + "logical" + ) + + expect_equal( + ds.class("test_fill$fac_col5")[[1]], + "logical" + ) +}) + +test_that("ds.standardiseDf changes levels if this option is selected", { + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "test_fill" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) c("1", "1"), + ask_question_wait_response_levels = function(levels_conflict, newobj, datasources) "1" + ) + + levels_2 <- ds.levels("test_fill$fac_col2") %>% map(~.$Levels) + levels_3 <- ds.levels("test_fill$fac_col3") %>% map(~.$Levels) + levels_4 <- ds.levels("test_fill$fac_col4") %>% map(~.$Levels) + levels_5 <- ds.levels("test_fill$fac_col5") %>% map(~.$Levels) + levels_6 <- ds.levels("test_fill$fac_col6") %>% map(~.$Levels) + levels_9 <- ds.levels("test_fill$fac_col9") %>% map(~.$Levels) + + expect_equal( + levels_2, + list( + server_1 = c("Blue", "Green", "Red"), + server_2 = c("Blue", "Green", "Red"), + server_3 = c("Blue", "Green", "Red") + ) + ) + + expect_equal( + levels_3, + list( + server_1 = c("No", "Yes"), + server_2 = c("No", "Yes"), + server_3 = c("No", "Yes") + ) + ) + + expect_equal( + levels_4, + list( + server_1 = c("1", "2", "3", "A", "B", "C"), + server_2 = c("1", "2", "3", "A", "B", "C"), + server_3 = c("1", "2", "3", "A", "B", "C") + ) + ) + + expect_equal( + levels_5, + list( + server_1 = c("1", "2", "3", "One", "Three", "Two"), + server_2 = c("1", "2", "3", "One", "Three", "Two"), + server_3 = c("1", "2", "3", "One", "Three", "Two") + ) + ) + + expect_equal( + levels_6, + list( + server_1 = c("Bird", "Cat", "Dog"), + server_2 = c("Bird", "Cat", "Dog"), + server_3 = c("Bird", "Cat", "Dog") + ) + ) + + expect_equal( + levels_9, + list( + server_1 = c("False", "True"), + server_2 = c("False", "True"), + server_3 = c("False", "True") + ) + ) + +}) + +test_that("ds.standardiseDf doesn't run if classes are not identical and fix_class is no", { + expect_error( + ds.standardiseDf( + df = "df", + newobj = "shouldnt_exist", + fix_class = "no" + ), + "Variables do not have the same class in all studies" + ) + + expect_equal( + ds.exists("shouldnt_exist")[[1]], + FALSE + ) +}) + +test_that("ds.standardiseDf doesn't run if levels are not identical and fix_class is no", { + expect_error( + with_mocked_bindings( + ds.standardiseDf( + df = "df", + newobj = "shouldnt_exist", + fix_levels = "no" + ), + prompt_user_class_decision_all_vars = function(var, server, classes, newobj, datasources) c("1", "1") + ), + "Factor variables do not have the same levels in all studies" + ) + + expect_equal( + ds.exists("shouldnt_exist")[[1]], + FALSE + ) +}) + + +## 9. Handle incorrect response for level fix + + + +