diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..856943c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +simulations/**/*.rds filter=lfs diff=lfs merge=lfs -text +simulations/**/*.log filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 130a84f..3d1096d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ docs inst/doc /doc/ /Meta/ +simulations/**/*.rds +simulations/**/*.log diff --git a/Detailed_functions_explanation_for_software_developers.Rmd b/Detailed_functions_explanation_for_software_developers.Rmd index 824fc76..472eb49 100644 --- a/Detailed_functions_explanation_for_software_developers.Rmd +++ b/Detailed_functions_explanation_for_software_developers.Rmd @@ -3,7 +3,7 @@ title: "Detailed_functions_explanation_for_software_developers" output: html_document --- -In this file I give a detailed explanation of the logic followed in the functions of the bonsaiforest2 library +In this file I give a detailed explanation of the logic followed in the functions of the bonsaiforest2 library. # File: prepare_formula_model.R Purpose: Prepares a multi-part brms formula and modified data for downstream model fitting; @@ -20,7 +20,10 @@ Returns: list(formula, data, response_type). Start | v -prepare_formula_model(data, response_formula_str, 'Optional predictive/prognostic formulas', response_type) +prepare_formula_model(data, +| response_formula_str, +| 'Optional predictive/prognostic formulas', +| response_type) | +-- 1) Argument validation & parse initial formula | - Calls: .parse_initial_formula(...) @@ -147,3 +150,187 @@ prepare_formula_model() → orchestrator .add_missing_prognostic_effects() → ensure main effects for interactions .assemble_brms_formula() → build brms::bf(..., nl=TRUE) and combine lf() parts + + +# File: fit_brms_model.R +Purpose: Wraps brms::brm() to fit the multi-part model produced by prepare_formula_model(), building and assigning appropriate priors (strings or brmsprior objects), handling family selection, and returning a brmsfit. +Input: prepared_model, + predictive_effect_priors, + prognostic_effect_priors, + stanvars +Returns: fitted `brmsfit` object. + +Start +| +v +fit_brms_model(prepared_model, +| predictive_effect_priors, +| prognostic_effect_priors, +| stanvars, ...) 
+| ++-- 1) Argument validation & unpack +| - Validates prepared_model is a named list containing formula (class +| brmsformula), data (data.frame), and response_type (one of +| "binary","count","continuous","survival") +| - Validates prior lists (predictive_effect_priors, prognostic_effect_priors) | and stanvars +| - Unpacks: formula, data, response_type +| +v ++-- 2) Determine brms family +| - Maps response_type -> family: +| * "continuous" -> gaussian() +| * "binary" -> bernoulli(link = "logit") +| * "count" -> negbinomial() +| * "survival" -> cox() +| - Stops on invalid response_type +| +v ++-- 3) Construct prior configuration & process each prior target +| - Defines prior_config: a list of target records for +| * shrunk prognostic (nlpar = "shprogeffect", default "horseshoe(1)") +| * unshrunk prognostic (nlpar = "unprogeffect", default "normal(0,5)") +| * prognostic intercept (nlpar = "unprogeffect", coef = "Intercept", default | "normal(0,5)") — skipped for survival +| * shrunk predictive (nlpar = "shpredeffect", default "horseshoe(1)") +| * unshrunk predictive (nlpar = "unpredeffect", default "normal(0,10)") +| - Fetches which nlpars actually exist from names(formula$pforms) +| - For each configured target present in formula: +| - Calls .process_and_retarget_prior(user_prior, +| target_nlpar, +| default_str, +| target_class, +| target_coef) +| - Collects returned brmsprior objects in priors_list +| - Tracks which defaults were used to message the user +| +v ++-- 4) .process_and_retarget_prior() helper logic (see below) +| +v ++-- 5) Combine priors & prepare final prior object +| - If any priors were created: final_priors <- Reduce("+", priors_list) +| - Else: final_priors <- brms::empty_prior() +| - Prints message listing defaults used (if any) +| +v ++-- 6) Call brms::brm() +| - Arguments: formula, +| data, +| family = model_family, +| prior = final_priors, +| stanvars, ... (other brm args) +| - Returns fitted brmsfit object +| +v +Return brmsfit + +### Key Details & Notes + +- Validation and safety: Uses checkmate to assert structure and types for prepared_model, prior lists, and stanvars. Requires prepared_model$formula be a brmsformula and prepared_model$data a non-empty data.frame. + +- Family mapping: A simple switch() maps response_type to brms family constructors; invalid types error out. + +- Prior configuration: The code only applies priors to nlpars present in formula$pforms (so unused components are ignored). Intercept prior is skipped for response_type == "survival" because survival non-linear parts have no intercept. Defaults are used when user does not provide a prior for a target; defaults are logged via message. + +- Prior processing behavior (.process_and_retarget_prior): +Accepts user_prior that may be: + * NULL → function uses default_str and flags default_used = TRUE. + * character string → wraps into a brmsprior via brms::set_prior(prior = "...", nlpar = target_nlpar, [class], [coef]). + * brmsprior object → retargets rows with empty nlpar to the target_nlpar and + (optionally) sets class and coef columns for those rows; logs an informative message about retargeting. + +If user_prior is a brmsprior, the helper only modifies rows where nlpar is blank (so already-targeted priors remain untouched). +The helper returns a list: list(prior = , default_used = ). +If user_prior is none of the accepted types, it errors. +Combining priors: +priors_list elements are combined with Reduce("+", priors_list) to produce a single brmsprior object accepted by brms::brm(). 
+If no priors are defined, brms::empty_prior() is used. + +- stanvars support: stanvars (class stanvars) are accepted and forwarded to brms::brm() unchanged, enabling custom Stan code inclusion. + +- `...` passthrough: All extra arguments are forwarded to brms::brm(), so the caller controls sampling parameters (chains, iter, warmup, backend, etc.). + +- Messaging: The function emits messages listing default priors that were used and retargeting actions when brmsprior objects are adjusted. + +- Return value: Returns the brmsfit object returned by brms::brm() for downstream use (posterior prediction, diagnostics, etc.). + +### Quick mapping: major internal helpers + +fit_brms_model() → orchestrator; validates input, chooses family, builds priors, calls brms::brm() + +.process_and_retarget_prior() → converts/retargets user-specified priors; returns brmsprior + flag indicating if default was used + +# File: run_brms_analysis.R +Purpose: High-level user-facing wrapper that prepares formula/data via prepare_formula_model() and fits the model via fit_brms_model(), returning a brmsfit. +Input: data, + response_formula_str, + shrunk_predictive_formula_str, + unshrunk_prognostic_formula_str, + unshrunk_predictive_formula_str, + shrunk_prognostic_formula_str, + response_type, + stratification_formula_str, + predictive_effect_priors, + prognostic_effect_priors, + stanvars +Returns: fitted `brmsfit` object. + +Start +| +v +run_brms_analysis(data, +| response_formula_str, +| response_type, ..., +| prognostic_effect_priors, +| predictive_effect_priors, +| stanvars, ...) +| ++-- 1) Prepare formula & data +| - Calls: prepare_formula_model(data, +| response_formula_str, +| shrunk_predictive_formula_str, +| unshrunk_prognostic_formula_str, +| unshrunk_predictive_formula_str, +| shrunk_prognostic_formula_str, +| response_type, +| stratification_formula_str) +| - Side-effects / outputs: +| - Messages: "Step 1: Preparing formula and data..." +| - Returns prepared_model (list with formula, data, response_type) +| +v ++-- 2) Fit the Bayesian model +| - Messages: "Step 2: Fitting the brms model..." +| - Calls: fit_brms_model(prepared_model = prepared_model, +| predictive_effect_priors = predictive_effect_priors, +| prognostic_effect_priors = prognostic_effect_priors, +| stanvars = stanvars, ...) +| - Passes through user ... args (chains, iter, cores, backend, etc.) +| - fit_brms_model() handles validation, family selection, prior construction, and brms::brm() call +| +v ++-- 3) Return final model +| - Messages: "Analysis complete." +| - Returns: fitted brmsfit object from fit_brms_model() +| +v +End + +### Key Details & Notes + +- run_brms_analysis() is a thin orchestrator: it does not validate priors itself; it delegates data/formula prep to prepare_formula_model() and prior handling + fitting to fit_brms_model(). + +- Defaults: predictive_effect_priors and prognostic_effect_priors default to empty lists; fit_brms_model() will apply defaults where needed and message which defaults were used. + +- Passthrough: All additional arguments in ... are forwarded to fit_brms_model() and ultimately to brms::brm(), giving callers full control over sampling and performance options. + +- Messages: The function prints three concise progress messages to help track the run steps. + +- Return contract: Always returns the brmsfit object produced by fit_brms_model(); callers should inspect/validate convergence and diagnostics themselves.
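+
+Below is a minimal, hypothetical usage sketch (not run). The argument names and the named-list structure of the prior arguments follow the documentation above; `sim_data`, its columns, and the formula strings are placeholders, and the optional formula parts that are omitted are assumed to default to NULL.
+
+```r
+# Hypothetical example: 'sim_data' and its columns (y, trt, x1, region, biomarker) are placeholders.
+fit <- run_brms_analysis(
+  data = sim_data,
+  response_formula_str = "y ~ trt",
+  shrunk_predictive_formula_str = "~ region + biomarker",
+  unshrunk_prognostic_formula_str = "~ x1",
+  response_type = "binary",
+  # Named lists; any target not supplied falls back to the documented default.
+  predictive_effect_priors = list(shrunk = "horseshoe(1)"),
+  prognostic_effect_priors = list(unshrunk = "normal(0, 5)"),
+  chains = 2, iter = 1000, cores = 2  # forwarded via ... to brms::brm()
+)
+```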
+ +###Quick mapping: major collaborators +run_brms_analysis() → orchestrator/wrapper + +prepare_formula_model() → builds multi-part brmsformula, prepares/encodes data + +fit_brms_model() → validates prepared object, constructs priors, calls brms::brm(), returns brmsfit + diff --git a/R/estimate_subgroup_effects.R b/R/estimate_subgroup_effects.R index 636a9aa..4b1e828 100644 --- a/R/estimate_subgroup_effects.R +++ b/R/estimate_subgroup_effects.R @@ -57,9 +57,9 @@ estimate_subgroup_effects <- function(brms_fit, choices = c("continuous", "binary", "count", "survival")) checkmate::assert_count(ndraws, null.ok = TRUE, positive = TRUE) - # CRITICAL FIX: Use brms_fit$data (processed data with contrasts) for all operations - # This ensures consistency between predictions and subgroup membership - # The original_data parameter is kept for backward compatibility but we use model data + # Use the processed data attached to the fitted model. + # This data contains the factor levels and contrasts actually used during fitting, + # ensuring predictions and subgroup assignments are consistent with the model. model_data <- brms_fit$data # Validate that model_data and original_data are compatible (same number of rows) @@ -68,7 +68,7 @@ estimate_subgroup_effects <- function(brms_fit, "Using the data from brms_fit object for consistency.") } - # --- 2. Create counterfactual datasets (UPDATED for Factor Logic) --- + # --- 2. Create counterfactual datasets (treatment/control counterfactuals) --- # Use model data (with contrasts) for everything message("Step 1: Creating counterfactual datasets...") @@ -85,7 +85,7 @@ estimate_subgroup_effects <- function(brms_fit, trt_var = trt_var ) - # --- 3. Generate posterior predictions --- + # --- 3. Generate posterior predictions (expected outcomes or survival components) --- message("Step 2: Generating posterior predictions...") posterior_preds <- .get_posterior_predictions( brms_fit = brms_fit, @@ -96,7 +96,7 @@ estimate_subgroup_effects <- function(brms_fit, ndraws = ndraws ) - # --- 4. Calculate marginal effects --- + # --- 4. Calculate marginal effects and summarize posterior draws --- message("Step 3: Calculating marginal effects...") results <- .calculate_and_summarize_effects( posterior_preds = posterior_preds, @@ -110,84 +110,181 @@ estimate_subgroup_effects <- function(brms_fit, return(results) } -#' Prepare and Validate Subgroup Variables (Formula Parsing) +#' Calculate Average Hazard Ratio (vectorized) #' -#' Validates subgroup inputs. If "auto", it scans the brms formula object -#' for interaction terms involving the treatment variable (e.g. `trt:subgroup`). +#' Compute per-draw average hazard ratio (AHR) from marginal survival curves. +#' The function accepts matrices of survival probabilities for control and +#' treatment arms where rows correspond to posterior draws and columns to time +#' points, and returns a numeric vector of AHR draws. 
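+#'
+#' For a single draw with survival vectors `sc` (control) and `st` (treatment)
+#' on a shared time grid, the code forms the increments `dsc <- -diff(c(1, sc))`
+#' and `dst <- -diff(c(1, st))` and returns `sum(sc * dst) / sum(st * dsc)`,
+#' a discrete analogue of the average hazard ratio; `NA` is returned for a
+#' draw whose denominator is zero.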
+#' @noRd +.calculate_ahr_vectorized <- function(S_control, S_treatment) { + n_draws <- nrow(S_control) + ahr_draws <- numeric(n_draws) + for (i in 1:n_draws) { + sc <- S_control[i, ] + st <- S_treatment[i, ] + dsc <- -diff(c(1, sc)) + dst <- -diff(c(1, st)) + + if(length(dsc) < length(sc)) dsc <- c(dsc, 0) + if(length(dst) < length(st)) dst <- c(dst, 0) + + num <- sum(sc * dst, na.rm=TRUE) + den <- sum(st * dsc, na.rm=TRUE) + ahr_draws[i] <- if (den == 0) NA else num/den + } + return(ahr_draws) +} + + +#' Calculate and summarize marginal subgroup effects #' +#' Given posterior predictions and subgroup membership, this helper computes +#' the marginal effect draws for each subgroup (or overall), then summarizes +#' them with robust posterior summaries (median and 95% interval). +#' This function supports continuous, binary, count and survival outcomes. #' @noRd -.prepare_subgroup_vars <- function(brms_fit, original_data, trt_var, subgroup_vars) { +.calculate_and_summarize_effects <- function(posterior_preds, original_data, subgroup_vars, is_overall, response_type) { - # Logic to handle "auto" or specific columns - checkmate::assert( - checkmate::check_string(subgroup_vars, pattern = "^auto$"), - checkmate::check_character(subgroup_vars, null.ok = TRUE, min.len = 1, unique = TRUE) - ) + all_results_list <- list() + all_draws_list <- list() - is_overall <- FALSE + # Pre-compute factor info for iteration + subgroup_factors <- list() + if (!is_overall) { + for (var in subgroup_vars) { + factor_var <- as.factor(original_data[[var]]) + subgroup_factors[[var]] <- list(factor = factor_var, levels = levels(factor_var)) + } + } - if (is.null(subgroup_vars)) { - is_overall <- TRUE - subgroup_vars <- "Overall" - } else if (identical(subgroup_vars, "auto")) { - message("`subgroup_vars` set to 'auto'. Detecting from model data...") + for (current_subgroup_var in subgroup_vars) { + if (!is_overall) message(paste("... 
processing", current_subgroup_var)) - # With the new explicit dummy approach, interaction dummies follow pattern: trt_VARNAMElevel - # We look for columns in model_data that match this pattern - dummy_pattern <- paste0("^", trt_var, "_") - interaction_cols <- grep(dummy_pattern, names(original_data), value = TRUE) + if (is_overall) { + current_data_subgroups <- as.factor(rep("Overall", nrow(original_data))) + subgroup_levels <- "Overall" + } else { + current_data_subgroups <- subgroup_factors[[current_subgroup_var]]$factor + subgroup_levels <- subgroup_factors[[current_subgroup_var]]$levels + } - if (length(interaction_cols) > 0) { - # Extract variable names by removing the trt_ prefix and the level suffix - # We check which factor variables in the data could have produced these dummies - detected_vars <- character(0) + level_results_list <- list() - for (var_name in names(original_data)) { - if (is.factor(original_data[[var_name]]) && var_name != trt_var) { - # Check if any dummy matches this variable - var_levels <- levels(original_data[[var_name]]) - # Use make.names to match the sanitized names created in prepare_formula_model - expected_dummies <- make.names(paste0(trt_var, "_", var_name, var_levels), unique = FALSE) - if (any(expected_dummies %in% interaction_cols)) { - detected_vars <- c(detected_vars, var_name) - } + for (level in subgroup_levels) { + subgroup_indices <- which(current_data_subgroups == level) + + if (response_type == "survival") { + effect_draws <- .calculate_survival_ahr_draws( + linpred_control = posterior_preds$linpred_control, + linpred_treatment = posterior_preds$linpred_treatment, + H0_posterior_list = posterior_preds$H0_posterior, + indices = subgroup_indices, + strat_var = posterior_preds$strat_var, + original_data = posterior_preds$original_data + ) + } else { + # For non-survival outcomes compute marginal (averaged) predictions per draw. + # Predictions are returned as matrices with rows = draws and cols = observations. + marginal_outcome_control <- if (length(subgroup_indices) == 1) { + posterior_preds$pred_control[, subgroup_indices] + } else { + rowMeans(posterior_preds$pred_control[, subgroup_indices, drop = FALSE]) } + + marginal_outcome_treatment <- if (length(subgroup_indices) == 1) { + posterior_preds$pred_treatment[, subgroup_indices] + } else { + rowMeans(posterior_preds$pred_treatment[, subgroup_indices, drop = FALSE]) + } + + effect_draws <- switch( + response_type, + continuous = marginal_outcome_treatment - marginal_outcome_control, + binary = qlogis(marginal_outcome_treatment) - qlogis(marginal_outcome_control), + count = marginal_outcome_treatment / marginal_outcome_control + ) } - detected_vars <- unique(detected_vars) - } else { - detected_vars <- character(0) - } - if (length(detected_vars) == 0) { - message("...no specific interaction terms detected. 
Calculating overall effect.") - is_overall <- TRUE - subgroup_vars <- "Overall" - } else { - subgroup_vars <- detected_vars - message(paste("...detected subgroup variable(s):", paste(subgroup_vars, collapse = ", "))) + subgroup_name <- if (is_overall) "Overall" else paste0(current_subgroup_var, ": ", level) + all_draws_list[[subgroup_name]] <- effect_draws + + point_estimate <- median(effect_draws, na.rm = TRUE) + ci <- quantile(effect_draws, probs = c(0.025, 0.975), na.rm = TRUE) + + level_results_list[[level]] <- tibble::tibble( + Subgroup = subgroup_name, + Median = point_estimate, + CI_Lower = ci[1], + CI_Upper = ci[2] + ) } - } else { - # Validate user provided vars exist - checkmate::assert_subset(subgroup_vars, names(original_data)) + all_results_list[[current_subgroup_var]] <- dplyr::bind_rows(level_results_list) } - if (is_overall) { - original_data$Overall <- "Overall" - } + final_results <- dplyr::bind_rows(all_results_list) + draws_df <- dplyr::bind_cols(all_draws_list) - return(list( - subgroup_vars = subgroup_vars, - data = original_data, - is_overall = is_overall - )) + return(list(estimates = final_results, draws = draws_df)) } -#' Create Counterfactual Datasets (With Explicit Dummy Recreation) + +#' Calculate Average Hazard Ratio (AHR) Draws for Survival Models #' -#' Creates "all control" and "all treatment" datasets. -#' Importantly, this function also recreates any interaction dummy variables -#' (e.g., trt_regionA, trt_regionB) that were created by prepare_formula_model. +#' For survival outcomes, compute per-draw AHRs by: +#' - reconstructing marginal survival curves for control and treatment, +#' - converting these to per-draw individual-level AHRs, +#' - averaging across individuals within the subgroup. +#' The implementation processes draws in chunks to limit memory usage. +#' @noRd +.calculate_survival_ahr_draws <- function(linpred_control, linpred_treatment, H0_posterior_list, indices, strat_var, original_data) { + # --- Assertions --- + checkmate::assert_matrix(linpred_control) + checkmate::assert_matrix(linpred_treatment) + checkmate::assert_list(H0_posterior_list, names = "named") + checkmate::assert_integerish(indices, min.len = 1, unique = TRUE) + checkmate::assert_string(strat_var, null.ok = TRUE) + checkmate::assert_data_frame(original_data) + + # Check dimensions + if (nrow(linpred_control) != nrow(linpred_treatment)) { + stop("linpred_control and linpred_treatment must have the same number of rows (draws).") + } + + n_draws <- nrow(linpred_control) + chunk_size <- min(1000, n_draws) + n_chunks <- ceiling(n_draws / chunk_size) + + ahr_draws <- numeric(n_draws) + + for (chunk in 1:n_chunks) { + start_idx <- (chunk - 1) * chunk_size + 1 + end_idx <- min(chunk * chunk_size, n_draws) + chunk_indices <- start_idx:end_idx + + S_control_chunk <- .get_marginal_survival_vectorized( + linpred_control[chunk_indices, , drop = FALSE], + H0_posterior_list, indices, strat_var, original_data + ) + S_treatment_chunk <- .get_marginal_survival_vectorized( + linpred_treatment[chunk_indices, , drop = FALSE], + H0_posterior_list, indices, strat_var, original_data + ) + + ahr_draws[chunk_indices] <- .calculate_ahr_vectorized(S_control_chunk, S_treatment_chunk) + } + + return(ahr_draws) +} + +#' Create treatment and control counterfactual datasets #' +#' Produce two datasets identical to the model data but with the treatment +#' variable set to the reference (control) or active (treatment) level for +#' all rows. 
Any explicit interaction dummy variables created during +#' preprocessing (pattern: trt_var_subgroupLEVEL) are updated to reflect the +#' chosen treatment level so predictions using explicit interaction columns +#' remain consistent. #' @noRd .create_counterfactual_datasets <- function(model_data, trt_var) { checkmate::assert_data_frame(model_data) @@ -195,8 +292,8 @@ estimate_subgroup_effects <- function(brms_fit, # Ensure treatment is a factor to get levels if (!is.factor(model_data[[trt_var]])) { - # Fallback if it was somehow converted to int, though prepare_formula handles this - model_data[[trt_var]] <- as.factor(model_data[[trt_var]]) + # Fallback if it was somehow converted to int, though prepare_formula handles this + model_data[[trt_var]] <- as.factor(model_data[[trt_var]]) } trt_levels <- levels(model_data[[trt_var]]) @@ -208,29 +305,28 @@ estimate_subgroup_effects <- function(brms_fit, ref_level <- trt_levels[1] # Usually 0 alt_level <- trt_levels[2] # Usually 1 - # Identify interaction dummy columns (pattern: trt_varname) - # These are created by .process_predictive_terms as trt_regionA, trt_regionB, etc. + # Identify explicit interaction dummy columns (pattern: trt_var_) interaction_dummy_pattern <- paste0("^", trt_var, "_") interaction_cols <- grep(interaction_dummy_pattern, names(model_data), value = TRUE) - # Create Control (Reference) Data + # Create control (reference) dataset data_control <- model_data data_control[[trt_var]] <- factor(rep(ref_level, nrow(model_data)), levels = trt_levels) contrasts(data_control[[trt_var]]) <- trt_contrasts - # Set all interaction dummies to 0 (since trt = control) + # Ensure interaction dummies are 0 in the control dataset for (col in interaction_cols) { data_control[[col]] <- 0 } - # Create Treatment Data + # Create treatment dataset and populate interaction dummies data_treatment <- model_data data_treatment[[trt_var]] <- factor(rep(alt_level, nrow(model_data)), levels = trt_levels) contrasts(data_treatment[[trt_var]]) <- trt_contrasts - # Recreate interaction dummies for treatment arm - # Pattern: trt_subgroupvarLEVEL, e.g., trt_regionA - # Value should be 1 if patient is in that subgroup level, 0 otherwise + # Recreate interaction dummies for treatment arm. For each explicit dummy + # created during preprocessing set the dummy to 1 when the patient belongs + # to the corresponding subgroup level (and treatment is active), otherwise 0. for (col in interaction_cols) { # Find which variable this corresponds to by checking which factor variable # in model_data has this level @@ -253,9 +349,8 @@ estimate_subgroup_effects <- function(brms_fit, if (matched) break } } - - # If no match found, the dummy column might be from a variable that's not a factor - # in the current data. Keep it at 0 for safety. + + # If no matching factor level is found, set the dummy to 0 for safety. if (!matched) { warning(paste("Could not match interaction column", col, "to any factor variable. Setting to 0.")) data_treatment[[col]] <- 0 @@ -273,11 +368,73 @@ estimate_subgroup_effects <- function(brms_fit, return(list(control = data_control, treatment = data_treatment)) } -#' Get Posterior Predictions (Robust) +#' Extract baseline hazard posterior predictions #' -#' Generates predictions. Handles Survival (manual bhaz reconstruction) -#' and Standard (posterior_epred) models. 
+#' Parses the fitted brms formula to identify the survival time and status +#' variables, reconstructs the spline basis used for the baseline hazard, +#' and returns the posterior baseline hazard evaluated at observed event times. +#' @noRd +.extract_baseline_hazard <- function(brms_fit, original_data, ndraws) { + # Parse formula for bhaz + lhs_formula_str_vec <- deparse(brms_fit$formula$formula[[2]]) + bhaz_term <- paste(lhs_formula_str_vec, collapse = " ") + + resp_match <- stringr::str_match(bhaz_term, "(\\w+)\\s*\\|\\s*cens\\(1\\s*-\\s*(\\w+)\\)") + if (is.na(resp_match[1,1])) stop("Could not parse 'time | cens(1 - status)' structure.") + + time_var <- resp_match[1, 2] + status_var <- resp_match[1, 3] + + strat_match <- stringr::str_match(bhaz_term, "gr\\s*=\\s*(\\w+)") + strat_var <- if (!is.na(strat_match[1, 2])) strat_match[1, 2] else NULL + + bknots_str <- stringr::str_extract(bhaz_term, "Boundary\\.knots = c\\(.*?\\)") + knot_str <- stringr::str_extract(bhaz_term, "(? 0) { + sbhaz_matrix_level <- as.matrix(sbhaz_draws_df[, sorted_cols]) + H0_posterior_list[[level]] <- as.matrix(i_spline_basis %*% t(sbhaz_matrix_level)) + } + } + } + + return(list(H0_posterior = H0_posterior_list, strat_var = strat_var)) +} + +#' Obtain posterior predictions for control and treatment counterfactuals #' +#' For non-survival responses the function returns expected outcomes from +#' `posterior_epred`. For survival responses it returns linear predictors +#' plus a reconstruction of the baseline hazard posterior so that marginal +#' survival curves can be computed outside brms. #' @noRd .get_posterior_predictions <- function(brms_fit, data_control, data_treatment, response_type, original_data, ndraws = NULL) { @@ -290,7 +447,7 @@ estimate_subgroup_effects <- function(brms_fit, data_combined <- rbind(data_control, data_treatment) if (response_type == "survival") { - message("... (reconstructing baseline hazard and getting linear predictors)...") + message("... (reconstructing baseline hazard and obtaining linear predictors)...") # 1. Get Linear Predictors (eta) for combined data # brms handles the trt:subgroup interaction here automatically @@ -304,8 +461,7 @@ estimate_subgroup_effects <- function(brms_fit, linpred_control <- linpred_combined[, 1:n_control] linpred_treatment <- linpred_combined[, (n_control + 1):ncol(linpred_combined)] - # 2. Reconstruct Baseline Hazard - # (Reuse the robust logic from your previous code) + # 2. Reconstruct baseline hazard posterior h0_res <- .extract_baseline_hazard(brms_fit, original_data, ndraws) return(list( @@ -336,203 +492,16 @@ estimate_subgroup_effects <- function(brms_fit, } } -#' Helper: Extract Baseline Hazard details for Survival -#' -#' Separated for cleanliness. Parses formula to find bhaz() terms. -#' @noRd -.extract_baseline_hazard <- function(brms_fit, original_data, ndraws) { - # Parse formula for bhaz - lhs_formula_str_vec <- deparse(brms_fit$formula$formula[[2]]) - bhaz_term <- paste(lhs_formula_str_vec, collapse = " ") - - resp_match <- stringr::str_match(bhaz_term, "(\\w+)\\s*\\|\\s*cens\\(1\\s*-\\s*(\\w+)\\)") - if (is.na(resp_match[1,1])) stop("Could not parse 'time | cens(1 - status)' structure.") - - time_var <- resp_match[1, 2] - status_var <- resp_match[1, 3] - strat_match <- stringr::str_match(bhaz_term, "gr\\s*=\\s*(\\w+)") - strat_var <- if (!is.na(strat_match[1, 2])) strat_match[1, 2] else NULL - - bknots_str <- stringr::str_extract(bhaz_term, "Boundary\\.knots = c\\(.*?\\)") - knot_str <- stringr::str_extract(bhaz_term, "(? 
0) { - sbhaz_matrix_level <- as.matrix(sbhaz_draws_df[, sorted_cols]) - H0_posterior_list[[level]] <- as.matrix(i_spline_basis %*% t(sbhaz_matrix_level)) - } - } - } - - return(list(H0_posterior = H0_posterior_list, strat_var = strat_var)) -} - - -#' Calculate and Summarize Marginal Effects (Identical Logic to Previous) -#' @noRd -.calculate_and_summarize_effects <- function(posterior_preds, original_data, subgroup_vars, is_overall, response_type) { - - all_results_list <- list() - all_draws_list <- list() - - # Pre-compute factor info for iteration - subgroup_factors <- list() - if (!is_overall) { - for (var in subgroup_vars) { - factor_var <- as.factor(original_data[[var]]) - subgroup_factors[[var]] <- list(factor = factor_var, levels = levels(factor_var)) - } - } - - for (current_subgroup_var in subgroup_vars) { - if (!is_overall) message(paste("... processing", current_subgroup_var)) - - if (is_overall) { - current_data_subgroups <- as.factor(rep("Overall", nrow(original_data))) - subgroup_levels <- "Overall" - } else { - current_data_subgroups <- subgroup_factors[[current_subgroup_var]]$factor - subgroup_levels <- subgroup_factors[[current_subgroup_var]]$levels - } - - level_results_list <- list() - - for (level in subgroup_levels) { - subgroup_indices <- which(current_data_subgroups == level) - - if (response_type == "survival") { - effect_draws <- .calculate_survival_ahr_draws( - linpred_control = posterior_preds$linpred_control, - linpred_treatment = posterior_preds$linpred_treatment, - H0_posterior_list = posterior_preds$H0_posterior, - indices = subgroup_indices, - strat_var = posterior_preds$strat_var, - original_data = posterior_preds$original_data - ) - } else { - # Vectorized means - # Check dimensions: rows=draws, cols=obs - marginal_outcome_control <- if (length(subgroup_indices) == 1) { - posterior_preds$pred_control[, subgroup_indices] - } else { - rowMeans(posterior_preds$pred_control[, subgroup_indices, drop = FALSE]) - } - - marginal_outcome_treatment <- if (length(subgroup_indices) == 1) { - posterior_preds$pred_treatment[, subgroup_indices] - } else { - rowMeans(posterior_preds$pred_treatment[, subgroup_indices, drop = FALSE]) - } - - effect_draws <- switch( - response_type, - continuous = marginal_outcome_treatment - marginal_outcome_control, - binary = qlogis(marginal_outcome_treatment) - qlogis(marginal_outcome_control), - count = marginal_outcome_treatment / marginal_outcome_control - ) - } - - subgroup_name <- if (is_overall) "Overall" else paste0(current_subgroup_var, ": ", level) - all_draws_list[[subgroup_name]] <- effect_draws - - point_estimate <- median(effect_draws, na.rm = TRUE) - ci <- quantile(effect_draws, probs = c(0.025, 0.975), na.rm = TRUE) - - level_results_list[[level]] <- tibble::tibble( - Subgroup = subgroup_name, - Median = point_estimate, - CI_Lower = ci[1], - CI_Upper = ci[2] - ) - } - all_results_list[[current_subgroup_var]] <- dplyr::bind_rows(level_results_list) - } - - final_results <- dplyr::bind_rows(all_results_list) - draws_df <- dplyr::bind_cols(all_draws_list) - - return(list(estimates = final_results, draws = draws_df)) -} - - -#' Calculate Average Hazard Ratio (AHR) Draws for Survival Models -#' (Unchanged from your logic, included for completeness) -#' @noRd -.calculate_survival_ahr_draws <- function(linpred_control, linpred_treatment, H0_posterior_list, indices, strat_var, original_data) { - # --- Assertions --- - checkmate::assert_matrix(linpred_control) - checkmate::assert_matrix(linpred_treatment) - 
checkmate::assert_list(H0_posterior_list, names = "named") - checkmate::assert_integerish(indices, min.len = 1, unique = TRUE) - checkmate::assert_string(strat_var, null.ok = TRUE) - checkmate::assert_data_frame(original_data) - - # Check dimensions - if (nrow(linpred_control) != nrow(linpred_treatment)) { - stop("linpred_control and linpred_treatment must have the same number of rows (draws).") - } - - # OPTIMIZATION: Process in chunks to reduce memory usage for large datasets - n_draws <- nrow(linpred_control) - chunk_size <- min(1000, n_draws) - n_chunks <- ceiling(n_draws / chunk_size) - - ahr_draws <- numeric(n_draws) - - for (chunk in 1:n_chunks) { - start_idx <- (chunk - 1) * chunk_size + 1 - end_idx <- min(chunk * chunk_size, n_draws) - chunk_indices <- start_idx:end_idx - - # Get survival curves for this chunk - S_control_chunk <- .get_marginal_survival_vectorized( - linpred_control[chunk_indices, , drop = FALSE], - H0_posterior_list, indices, strat_var, original_data - ) - S_treatment_chunk <- .get_marginal_survival_vectorized( - linpred_treatment[chunk_indices, , drop = FALSE], - H0_posterior_list, indices, strat_var, original_data - ) - - # Calculate AHR for this chunk using vectorized operations - ahr_draws[chunk_indices] <- .calculate_ahr_vectorized(S_control_chunk, S_treatment_chunk) - } - - return(ahr_draws) -} - -#' Calculate Vectorized Marginal Survival Curve +#' Compute marginal survival curves for a subgroup (vectorized) +#' +#' Given linear predictors for a set of individuals (subset), and a list of +#' posterior baseline hazards (possibly stratified), compute the subgroup-level +#' marginal survival curve for each posterior draw. The result is a matrix with +#' rows = draws and columns = time points. #' @noRd .get_marginal_survival_vectorized <- function(linpred_posterior, H0_post_list, sub_indices, strat_variable, full_data) { - # ... [Same as your previous implementation] ... subgroup_linpred <- linpred_posterior[, sub_indices, drop = FALSE] n_draws <- nrow(subgroup_linpred) @@ -565,27 +534,71 @@ estimate_subgroup_effects <- function(brms_fit, return(S_marginal) } -#' Calculate AHR Vectorized + +#' Determine subgroup variables and prepare data for summarization +#' +#' Validates `subgroup_vars`. If set to `"auto"` the function inspects the +#' preprocessed model data for explicit interaction dummy columns (pattern: +#' `trt_var_level`) and infers which factor variables were used as subgroup +#' modifiers. When no interactions are found, it prepares the data for an +#' overall marginal effect. #' @noRd -.calculate_ahr_vectorized <- function(S_control, S_treatment) { - # ... [Same as your previous implementation] ... 
- n_draws <- nrow(S_control) - ahr_draws <- numeric(n_draws) - for (i in 1:n_draws) { - sc <- S_control[i, ] - st <- S_treatment[i, ] - dsc <- -diff(c(1, sc)) - dst <- -diff(c(1, st)) +.prepare_subgroup_vars <- function(brms_fit, original_data, trt_var, subgroup_vars) { - # Pad - if(length(dsc) < length(sc)) dsc <- c(dsc, 0) - if(length(dst) < length(st)) dst <- c(dst, 0) + # Logic to handle "auto" or specific columns + checkmate::assert( + checkmate::check_string(subgroup_vars, pattern = "^auto$"), + checkmate::check_character(subgroup_vars, null.ok = TRUE, min.len = 1, unique = TRUE) + ) - num <- sum(sc * dst, na.rm=TRUE) - den <- sum(st * dsc, na.rm=TRUE) - ahr_draws[i] <- if (den == 0) NA else num/den + is_overall <- FALSE + + if (is.null(subgroup_vars)) { + is_overall <- TRUE + subgroup_vars <- "Overall" + } else if (identical(subgroup_vars, "auto")) { + message("`subgroup_vars` set to 'auto'. Detecting from model data...") + + dummy_pattern <- paste0("^", trt_var, "_") + interaction_cols <- grep(dummy_pattern, names(original_data), value = TRUE) + + if (length(interaction_cols) > 0) { + detected_vars <- character(0) + for (var_name in names(original_data)) { + if (is.factor(original_data[[var_name]]) && var_name != trt_var) { + var_levels <- levels(original_data[[var_name]]) + expected_dummies <- make.names(paste0(trt_var, "_", var_name, var_levels), unique = FALSE) + if (any(expected_dummies %in% interaction_cols)) { + detected_vars <- c(detected_vars, var_name) + } + } + } + detected_vars <- unique(detected_vars) + } else { + detected_vars <- character(0) + } + + if (length(detected_vars) == 0) { + message("...no specific interaction terms detected. Calculating overall effect.") + is_overall <- TRUE + subgroup_vars <- "Overall" + } else { + subgroup_vars <- detected_vars + message(paste("...detected subgroup variable(s):", paste(subgroup_vars, collapse = ", "))) + } + } else { + # Validate user provided vars exist + checkmate::assert_subset(subgroup_vars, names(original_data)) + } + + if (is_overall) { + original_data$Overall <- "Overall" } - return(ahr_draws) -} + return(list( + subgroup_vars = subgroup_vars, + data = original_data, + is_overall = is_overall + )) +} diff --git a/R/fit_brms_model.R b/R/fit_brms_model.R index 243726a..713a282 100644 --- a/R/fit_brms_model.R +++ b/R/fit_brms_model.R @@ -75,8 +75,7 @@ fit_brms_model <- function(prepared_model, stanvars = NULL, ...) { - # --- 1. Argument Validation (Validate-First) --- - + # --- 1. Argument Validation --- # 1a. Validate the container and its structure checkmate::assert_list(prepared_model, names = "named", .var.name = "prepared_model") checkmate::assert_names( @@ -96,7 +95,7 @@ fit_brms_model <- function(prepared_model, checkmate::assert_list(prognostic_effect_priors, names = "named", null.ok = TRUE) checkmate::assert_class(stanvars, "stanvars", null.ok = TRUE) - # --- Unpack --- + # Unpack formula <- prepared_model$formula data <- prepared_model$data response_type <- prepared_model$response_type @@ -114,35 +113,33 @@ fit_brms_model <- function(prepared_model, } # --- 3. Construct the Prior List --- - # Define all possible prior components - # KEY CHANGE: The Intercept is class 'b' with coef 'Intercept' prior_config <- list( - # Shrunk Prognostic (b) - No intercept by definition ( ~ 0 + ...) 
+ # Shrunk Prognostic (b) list(nlpar = "shprogeffect", class = "b", coef = NULL, user_prior = prognostic_effect_priors$shrunk, default = "horseshoe(1)", label = "shrunk prognostic (b)"), - # Unshrunk Prognostic (b) - NON-intercepts + # Unshrunk Prognostic (b) list(nlpar = "unprogeffect", class = "b", coef = NULL, user_prior = prognostic_effect_priors$unshrunk, default = "normal(0, 5)", label = "unshrunk prognostic (b)"), # Unshrunk Prognostic (Intercept) - list(nlpar = "unprogeffect", class = "b", coef = "Intercept", # <-- CHANGED + list(nlpar = "unprogeffect", class = "b", coef = "Intercept", user_prior = prognostic_effect_priors$intercept, default = "normal(0, 5)", label = "prognostic intercept"), - # Shrunk Predictive (b) - No intercept by definition + # Shrunk Predictive (b) list(nlpar = "shpredeffect", class = "b", coef = NULL, user_prior = predictive_effect_priors$shrunk, default = "horseshoe(1)", label = "shrunk predictive (b)"), - # Unshrunk Predictive (b) - No intercept by definition + # Unshrunk Predictive (b) list(nlpar = "unpredeffect", class = "b", coef = NULL, user_prior = predictive_effect_priors$unshrunk, default = "normal(0, 10)", @@ -159,7 +156,6 @@ fit_brms_model <- function(prepared_model, # Special case: Intercept prior only relevant if nlpar is unprogeffect # AND the response type is NOT survival (which has no intercept) - # CHANGED: Check conf$coef now, not conf$class if (!is.null(conf$coef) && conf$coef == "Intercept" && response_type == "survival") { next # Skip intercept prior for survival models } @@ -169,7 +165,7 @@ fit_brms_model <- function(prepared_model, target_nlpar = conf$nlpar, default_str = conf$default, target_class = conf$class, - target_coef = conf$coef # <-- PASSING NEW ARG + target_coef = conf$coef ) priors_list <- c(priors_list, list(processed$prior)) @@ -227,7 +223,7 @@ fit_brms_model <- function(prepared_model, #' @noRd .process_and_retarget_prior <- function(user_prior, target_nlpar, default_str, target_class = NULL, target_coef = NULL) { - # --- Assertions for helper function --- + # Assertions for helper function checkmate::assert_string(target_nlpar, min.chars = 1) checkmate::assert_string(default_str, min.chars = 1) checkmate::assert_string(target_class, null.ok = TRUE) @@ -244,7 +240,6 @@ fit_brms_model <- function(prepared_model, } if (is.character(prior_to_use)) { - # --- THIS BLOCK IS REVISED --- # Build a list of arguments, excluding NULLs args <- list( prior = prior_to_use, @@ -290,7 +285,6 @@ fit_brms_model <- function(prepared_model, } } return(list(prior = modified_prior, default_used = FALSE)) - # --- END REVISION --- } stop(paste("Prior for", target_nlpar, "must be NULL, a string, or a brmsprior object."), call. = FALSE) diff --git a/R/run_brms_analysis.R b/R/run_brms_analysis.R index 7a6380f..f3b0e77 100644 --- a/R/run_brms_analysis.R +++ b/R/run_brms_analysis.R @@ -5,7 +5,7 @@ #' #' This function is the main user-facing entry point. It first calls #' `prepare_formula_model` to build the `brmsformula` and process the data, -#' then passes the results to `fit_brms_model` to run the analysis. +#' then passes the results to `fit_brms_model` to fit the model. #' #' @param data A data.frame containing all the necessary variables. #' @param response_formula_str A character string for the response part, e.g., @@ -54,8 +54,6 @@ #' sim_data$subgroup <- as.factor(sim_data$subgroup) #' #' # 2. Run the full analysis -#' # We use \dontrun{} because fitting a model requires Stan compilation -#' # which may fail in automated CI/CD environments. 
#' \dontrun{ #' full_fit <- run_brms_analysis( #' data = sim_data, @@ -107,15 +105,14 @@ run_brms_analysis <- function(data, # --- 2. Fit the Bayesian Model --- message("\nStep 2: Fitting the brms model...") - # --- THIS IS THE MODIFIED PART --- model_fit <- fit_brms_model( - prepared_model = prepared_model, # Pass the entire list + prepared_model = prepared_model, predictive_effect_priors = predictive_effect_priors, prognostic_effect_priors = prognostic_effect_priors, stanvars = stanvars, ... ) - # --- END OF MODIFICATION --- + # --- 3. Return the Final Model --- message("\nAnalysis complete.") diff --git a/simulations/.gitattributes b/simulations/.gitattributes new file mode 100644 index 0000000..856943c --- /dev/null +++ b/simulations/.gitattributes @@ -0,0 +1,2 @@ +simulations/**/*.rds filter=lfs diff=lfs merge=lfs -text +simulations/**/*.log filter=lfs diff=lfs merge=lfs -text diff --git a/simulations/.gitignore b/simulations/.gitignore new file mode 100644 index 0000000..e69de29