|
47 | 47 | #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
|
48 | 48 | #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
|
49 | 49 | #' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
|
| 50 | +#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. |
| 51 | +#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. |
| 52 | +#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. |
| 53 | +#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. |
| 54 | +#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. |
| 55 | +#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. |
50 | 56 | #'
|
51 | 57 | #' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional.
|
52 | 58 | #'
|
@@ -162,7 +168,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
|
162 | 168 | treated_coding_init = 0.5, rfx_prior_var = NULL,
|
163 | 169 | random_seed = -1, keep_burnin = FALSE, keep_gfr = FALSE,
|
164 | 170 | keep_every = 1, num_chains = 1, verbose = FALSE,
|
165 |
| - probit_outcome_model = FALSE |
| 171 | + probit_outcome_model = FALSE, |
| 172 | + rfx_working_parameter_prior_mean = NULL, |
| 173 | + rfx_group_parameter_prior_mean = NULL, |
| 174 | + rfx_working_parameter_prior_cov = NULL, |
| 175 | + rfx_group_parameter_prior_cov = NULL, |
| 176 | + rfx_variance_prior_shape = 1, |
| 177 | + rfx_variance_prior_scale = 1 |
166 | 178 | )
|
167 | 179 | general_params_updated <- preprocessParams(
|
168 | 180 | general_params_default, general_params
|
@@ -230,6 +242,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
|
230 | 242 | num_chains <- general_params_updated$num_chains
|
231 | 243 | verbose <- general_params_updated$verbose
|
232 | 244 | probit_outcome_model <- general_params_updated$probit_outcome_model
|
| 245 | + rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean |
| 246 | + rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean |
| 247 | + rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov |
| 248 | + rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov |
| 249 | + rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape |
| 250 | + rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale |
233 | 251 |
|
234 | 252 | # 2. Mu forest parameters
|
235 | 253 | num_trees_mu <- prognostic_forest_params_updated$num_trees
|
@@ -842,24 +860,39 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
|
842 | 860 |
|
843 | 861 | # Random effects prior parameters
|
844 | 862 | if (has_rfx) {
|
845 |
| - # Initialize the working parameter to 1 |
846 |
| - if (num_rfx_components < 1) { |
847 |
| - stop("There must be at least 1 random effect component") |
| 863 | + # Prior parameters |
| 864 | + if (is.null(rfx_working_parameter_prior_mean)) { |
| 865 | + if (num_rfx_components == 1) { |
| 866 | + alpha_init <- c(1) |
| 867 | + } else if (num_rfx_components > 1) { |
| 868 | + alpha_init <- c(1,rep(0,num_rfx_components-1)) |
| 869 | + } else { |
| 870 | + stop("There must be at least 1 random effect component") |
| 871 | + } |
| 872 | + } else { |
| 873 | + alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) |
| 874 | + } |
| 875 | + |
| 876 | + if (is.null(rfx_group_parameter_prior_mean)) { |
| 877 | + xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups) |
| 878 | + } else { |
| 879 | + xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) |
848 | 880 | }
|
849 |
| - alpha_init <- rep(1,num_rfx_components) |
850 |
| - # Initialize each group parameter based on a regression of outcome on basis in that grou |
851 |
| - xi_init <- matrix(0,num_rfx_components,num_rfx_groups) |
852 |
| - for (i in 1:num_rfx_groups) { |
853 |
| - group_subset_indices <- rfx_group_ids_train == i |
854 |
| - basis_group <- rfx_basis_train[group_subset_indices,] |
855 |
| - resid_group <- resid_train[group_subset_indices] |
856 |
| - rfx_group_model <- lm(resid_group ~ 0+basis_group) |
857 |
| - xi_init[,i] <- unname(coef(rfx_group_model)) |
| 881 | + |
| 882 | + if (is.null(rfx_working_parameter_prior_cov)) { |
| 883 | + sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) |
| 884 | + } else { |
| 885 | + sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) |
| 886 | + } |
| 887 | + |
| 888 | + if (is.null(rfx_group_parameter_prior_cov)) { |
| 889 | + sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components) |
| 890 | + } else { |
| 891 | + sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) |
858 | 892 | }
|
859 |
| - sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) |
860 |
| - sigma_xi_init <- diag(rfx_prior_var) |
861 |
| - sigma_xi_shape <- 1 |
862 |
| - sigma_xi_scale <- 1 |
| 893 | + |
| 894 | + sigma_xi_shape <- rfx_variance_prior_shape |
| 895 | + sigma_xi_scale <- rfx_variance_prior_scale |
863 | 896 | }
|
864 | 897 |
|
865 | 898 | # Random effects data structure and storage container
|
|
0 commit comments