Skip to content

Commit 6affae5

Browse files
authored
Merge pull request #179 from StochasticTree/make_rfx_parameters_flexible
Allow users to set random effects prior parameters in BART and BCF
2 parents 78cc4c6 + 0726617 commit 6affae5

19 files changed

+1121
-57
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ export(createRandomEffectSamples)
3838
export(createRandomEffectsDataset)
3939
export(createRandomEffectsModel)
4040
export(createRandomEffectsTracker)
41+
export(expand_dims_1d)
42+
export(expand_dims_2d)
43+
export(expand_dims_2d_diag)
4144
export(getRandomEffectSamples)
4245
export(loadForestContainerCombinedJson)
4346
export(loadForestContainerCombinedJsonString)

R/bart.R

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@
4545
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
4646
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
4747
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
48+
#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
49+
#' - `rfx_group_parameter_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
50+
#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
51+
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
52+
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
53+
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
4854
#'
4955
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
5056
#'
@@ -118,7 +124,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
118124
variable_weights = NULL, random_seed = -1,
119125
keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
120126
num_chains = 1, verbose = FALSE,
121-
probit_outcome_model = FALSE
127+
probit_outcome_model = FALSE,
128+
rfx_working_parameter_prior_mean = NULL,
129+
rfx_group_parameter_prior_mean = NULL,
130+
rfx_working_parameter_prior_cov = NULL,
131+
rfx_group_parameter_prior_cov = NULL,
132+
rfx_variance_prior_shape = 1,
133+
rfx_variance_prior_scale = 1
122134
)
123135
general_params_updated <- preprocessParams(
124136
general_params_default, general_params
@@ -168,6 +180,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
168180
num_chains <- general_params_updated$num_chains
169181
verbose <- general_params_updated$verbose
170182
probit_outcome_model <- general_params_updated$probit_outcome_model
183+
rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean
184+
rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean
185+
rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov
186+
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
187+
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
188+
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
171189

172190
# 2. Mean forest parameters
173191
num_trees_mean <- mean_forest_params_updated$num_trees
@@ -673,18 +691,38 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
673691
# Random effects initialization
674692
if (has_rfx) {
675693
# Prior parameters
676-
if (num_rfx_components == 1) {
677-
alpha_init <- c(1)
678-
} else if (num_rfx_components > 1) {
679-
alpha_init <- c(1,rep(0,num_rfx_components-1))
694+
if (is.null(rfx_working_parameter_prior_mean)) {
695+
if (num_rfx_components == 1) {
696+
alpha_init <- c(1)
697+
} else if (num_rfx_components > 1) {
698+
alpha_init <- c(1,rep(0,num_rfx_components-1))
699+
} else {
700+
stop("There must be at least 1 random effect component")
701+
}
702+
} else {
703+
alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
704+
}
705+
706+
if (is.null(rfx_group_parameter_prior_mean)) {
707+
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
708+
} else {
709+
xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
710+
}
711+
712+
if (is.null(rfx_working_parameter_prior_cov)) {
713+
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
714+
} else {
715+
sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
716+
}
717+
718+
if (is.null(rfx_group_parameter_prior_cov)) {
719+
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
680720
} else {
681-
stop("There must be at least 1 random effect component")
721+
sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
682722
}
683-
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
684-
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
685-
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
686-
sigma_xi_shape <- 1
687-
sigma_xi_scale <- 1
723+
724+
sigma_xi_shape <- rfx_variance_prior_shape
725+
sigma_xi_scale <- rfx_variance_prior_scale
688726

689727
# Random effects data structure and storage container
690728
rfx_dataset_train <- createRandomEffectsDataset(rfx_group_ids_train, rfx_basis_train)

R/bcf.R

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
4848
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
4949
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
50+
#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
51+
#' - `rfx_group_parameter_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
52+
#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
53+
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
54+
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
55+
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
5056
#'
5157
#' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional.
5258
#'
@@ -162,7 +168,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
162168
treated_coding_init = 0.5, rfx_prior_var = NULL,
163169
random_seed = -1, keep_burnin = FALSE, keep_gfr = FALSE,
164170
keep_every = 1, num_chains = 1, verbose = FALSE,
165-
probit_outcome_model = FALSE
171+
probit_outcome_model = FALSE,
172+
rfx_working_parameter_prior_mean = NULL,
173+
rfx_group_parameter_prior_mean = NULL,
174+
rfx_working_parameter_prior_cov = NULL,
175+
rfx_group_parameter_prior_cov = NULL,
176+
rfx_variance_prior_shape = 1,
177+
rfx_variance_prior_scale = 1
166178
)
167179
general_params_updated <- preprocessParams(
168180
general_params_default, general_params
@@ -230,6 +242,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
230242
num_chains <- general_params_updated$num_chains
231243
verbose <- general_params_updated$verbose
232244
probit_outcome_model <- general_params_updated$probit_outcome_model
245+
rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean
246+
rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean
247+
rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov
248+
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
249+
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
250+
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
233251

234252
# 2. Mu forest parameters
235253
num_trees_mu <- prognostic_forest_params_updated$num_trees
@@ -842,24 +860,39 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
842860

843861
# Random effects prior parameters
844862
if (has_rfx) {
845-
# Initialize the working parameter to 1
846-
if (num_rfx_components < 1) {
847-
stop("There must be at least 1 random effect component")
863+
# Prior parameters
864+
if (is.null(rfx_working_parameter_prior_mean)) {
865+
if (num_rfx_components == 1) {
866+
alpha_init <- c(1)
867+
} else if (num_rfx_components > 1) {
868+
alpha_init <- c(1,rep(0,num_rfx_components-1))
869+
} else {
870+
stop("There must be at least 1 random effect component")
871+
}
872+
} else {
873+
alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components)
874+
}
875+
876+
if (is.null(rfx_group_parameter_prior_mean)) {
877+
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
878+
} else {
879+
xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups)
848880
}
849-
alpha_init <- rep(1,num_rfx_components)
850-
# Initialize each group parameter based on a regression of outcome on basis in that grou
851-
xi_init <- matrix(0,num_rfx_components,num_rfx_groups)
852-
for (i in 1:num_rfx_groups) {
853-
group_subset_indices <- rfx_group_ids_train == i
854-
basis_group <- rfx_basis_train[group_subset_indices,]
855-
resid_group <- resid_train[group_subset_indices]
856-
rfx_group_model <- lm(resid_group ~ 0+basis_group)
857-
xi_init[,i] <- unname(coef(rfx_group_model))
881+
882+
if (is.null(rfx_working_parameter_prior_cov)) {
883+
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
884+
} else {
885+
sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components)
886+
}
887+
888+
if (is.null(rfx_group_parameter_prior_cov)) {
889+
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
890+
} else {
891+
sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components)
858892
}
859-
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
860-
sigma_xi_init <- diag(rfx_prior_var)
861-
sigma_xi_shape <- 1
862-
sigma_xi_scale <- 1
893+
894+
sigma_xi_shape <- rfx_variance_prior_shape
895+
sigma_xi_scale <- rfx_variance_prior_scale
863896
}
864897

865898
# Random effects data structure and storage container

R/utils.R

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,3 +855,86 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) {
855855
}
856856
return(x_preprocessed)
857857
}
858+
859+
#' Convert scalar input to vector of dimension `output_size`,
#' or check that input array is equivalent to a vector of dimension `output_size`.
#'
#' A length-one `input` is repeated `output_size` times. A numeric `input` of
#' any other length is returned unchanged, provided it has exactly
#' `output_size` elements. Anything else raises an error.
#'
#' @param input Scalar (length-one) value to be repeated, or a numeric vector
#'   with `output_size` elements (passed through as-is)
#' @param output_size Intended size of the output vector
#' @return A vector of length `output_size`
#' @export
expand_dims_1d <- function(input, output_size) {
  if (length(input) == 1) {
    # Scalar: broadcast to the requested length
    output <- rep(input, output_size)
  } else if (is.numeric(input)) {
    if (length(input) != output_size) {
      stop("`input` must be a numeric vector with `output_size` elements")
    }
    output <- input
  } else {
    stop("`input` must be either a numeric vector or a scalar that can be repeated `output_size` times")
  }
  return(output)
}
879+
880+
#' Ensures that input is propagated appropriately to a matrix of dimension `output_rows` x `output_cols`.
#' Handles the following cases:
#' 1. `input` is a scalar: output is simply a (`output_rows`, `output_cols`) matrix with `input` repeated for each element
#' 2. `input` is a vector of length `output_rows`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_cols` columns
#' 3. `input` is a vector of length `output_cols`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_rows` rows
#' 4. `input` is a matrix of dimension (`output_rows`, `output_cols`): input is passed through as-is
#' All other cases throw an error.
#'
#' When `output_rows == output_cols`, a vector of that length matches both
#' case 2 and case 3; case 3 takes precedence (every row of the output equals
#' `input`).
#'
#' @param input Input to be converted to a matrix (or passed through as-is)
#' @param output_rows Intended number of rows in the output array
#' @param output_cols Intended number of columns in the output array
#' @return A matrix of dimension `output_rows` x `output_cols`
#' @export
expand_dims_2d <- function(input, output_rows, output_cols) {
  if (length(input) == 1) {
    output <- matrix(rep(input, output_rows * output_cols), ncol = output_cols)
  } else if (is.matrix(input) &&
             nrow(input) == output_rows &&
             ncol(input) == output_cols) {
    # Must be tested before `is.numeric()`: a numeric matrix is also
    # "numeric", so a correctly-shaped matrix would otherwise be routed to
    # the vector branch and rejected on its total length.
    output <- input
  } else if (is.numeric(input)) {
    # Numeric vectors (and numeric matrices of any other shape) are matched
    # on their total length.
    if (length(input) == output_cols) {
      output <- matrix(rep(input, output_rows), nrow = output_rows, byrow = TRUE)
    } else if (length(input) == output_rows) {
      output <- matrix(rep(input, output_cols), ncol = output_cols, byrow = FALSE)
    } else {
      stop("If `input` is a vector, it must either contain `output_rows` or `output_cols` elements")
    }
  } else if (is.matrix(input)) {
    # Non-numeric matrix whose shape does not match the requested output
    if (nrow(input) != output_rows) {
      stop("`input` must be a matrix with `output_rows` rows")
    }
    stop("`input` must be a matrix with `output_cols` columns")
  } else {
    stop("`input` must be either a matrix, vector or a scalar")
  }
  return(output)
}
917+
918+
#' Convert scalar input to square matrix of dimension `output_size` x `output_size` with `input` along the diagonal,
#' or check that input array is equivalent to a square matrix of dimension `output_size` x `output_size`.
#'
#' A matrix `input` (including a 1 x 1 matrix) must already be
#' `output_size` x `output_size` and is returned unchanged. A non-matrix
#' `input` must have length one and is placed on the diagonal of an otherwise
#' zero matrix. Anything else raises an error.
#'
#' @param input Input to be converted to a square matrix (or passed through as-is)
#' @param output_size Intended row and column dimension of the square output matrix
#' @return A square matrix of dimension `output_size` x `output_size`
#' @export
expand_dims_2d_diag <- function(input, output_size) {
  if (is.matrix(input)) {
    # Matrices are handled first: `diag(x, nrow)` refuses an `nrow` argument
    # when `x` is a matrix, so a 1 x 1 matrix must not reach the scalar path.
    if (nrow(input) != ncol(input)) {
      stop("`input` must be a square matrix")
    }
    if (nrow(input) != output_size) {
      stop("`input` must be a square matrix with `output_size` rows and columns")
    }
    output <- input
  } else if (length(input) == 1) {
    output <- as.matrix(diag(input, output_size))
  } else {
    stop("`input` must be either a square matrix or a scalar")
  }
  return(output)
}

man/bart.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)