|
13 | 13 | # You should have received a copy of the GNU General Public License along with |
14 | 14 | # this program; if not, see <http://www.gnu.org/licenses/>. |
15 | 15 |
|
| 16 | + |
| 17 | +.rename_scalar <- function(sso, oldname = "lp__", newname = "log-posterior") { |
| 18 | + p <- which(sso@param_names == oldname) |
| 19 | + if (identical(integer(0), p)) |
| 20 | + return(sso) |
| 21 | + sso@param_names[p] <- |
| 22 | + dimnames(sso@samps_all)$parameters[p] <- |
| 23 | + names(sso@param_dims)[which(names(sso@param_dims) == oldname)] <- newname |
| 24 | + sso |
| 25 | +} |
| 26 | + |
16 | 27 | # convert stanfit object to shinystan object |
17 | 28 | stan2shinystan <- function(stanfit, model_name, notes) { |
18 | 29 | # notes: text to add to user_model_info slot |
19 | 30 | rstan_check() |
20 | | - if (!inherits(stanfit, "stanfit")) { |
| 31 | + if (!is.stanfit(stanfit)) { |
21 | 32 | name <- deparse(substitute(stanfit)) |
22 | 33 | stop(paste(name, "is not a stanfit object.")) |
23 | 34 | } |
24 | 35 |
|
25 | | - stan_args <- stanfit@stan_args[[1]] |
| 36 | + stan_args <- stanfit@stan_args[[1L]] |
| 37 | + stan_method <- stan_args$method |
| 38 | + vb <- stan_method == "variational" |
26 | 39 | from_cmdstan_csv <- ("engine" %in% names(stan_args)) |
27 | 40 | stan_algorithm <- if (from_cmdstan_csv) |
28 | 41 | toupper(stan_args$engine) else stan_args$algorithm |
29 | 42 | warmup <- if (from_cmdstan_csv) stanfit@sim$warmup2[1L] else stanfit@sim$warmup |
30 | 43 | nWarmup <- if (from_cmdstan_csv) warmup else floor(warmup / stanfit@sim$thin) |
31 | 44 |
|
32 | | - max_td <- stanfit@stan_args[[1]]$control |
33 | | - if (is.null(max_td)) |
| 45 | + cntrl <- stanfit@stan_args[[1L]]$control |
| 46 | + if (is.null(cntrl)) |
34 | 47 | max_td <- 11 |
35 | 48 | else { |
36 | | - max_td <- max_td$max_treedepth |
| 49 | + max_td <- cntrl$max_treedepth |
37 | 50 | if (is.null(max_td)) |
38 | 51 | max_td <- 11 |
39 | 52 | } |
40 | 53 |
|
41 | 54 | samps_all <- rstan::extract(stanfit, permuted = FALSE, inc_warmup = TRUE) |
42 | | - param_names <- dimnames(samps_all)[[3]] # stanfit@sim$fnames_oi |
| 55 | + param_names <- dimnames(samps_all)[[3L]] # stanfit@sim$fnames_oi |
43 | 56 | param_dims <- stanfit@sim$dims_oi |
44 | 57 |
|
45 | | - if (!(stan_algorithm %in% c("NUTS", "HMC"))) { |
46 | | - warning("Most shinyStan features are only available for models using |
| 58 | + if (!vb && !(stan_algorithm %in% c("NUTS", "HMC"))) { |
| 59 | + warning("Most features are only available for models using |
47 | 60 | algorithm NUTS or algorithm HMC.") |
48 | 61 | } |
49 | | - |
50 | 62 | mname <- if (!missing(model_name)) model_name else stanfit@model_name |
51 | 63 | mcode <- rstan::get_stancode(stanfit) |
52 | 64 |
|
| 65 | + sampler_params <- if (vb) list(NA) else rstan::get_sampler_params(stanfit) |
| 66 | + stan_summary <- rstan::summary(stanfit)$summary |
| 67 | + if (vb) stan_summary <- cbind(stan_summary, Rhat = NA, n_eff = NA, se_mean = NA) |
| 68 | + |
53 | 69 | slots <- list() |
54 | 70 | slots$Class <- "shinystan" |
55 | 71 | slots$model_name <- mname |
56 | 72 | slots$param_names <- param_names |
57 | 73 | slots$param_dims <- param_dims |
58 | 74 | slots$samps_all <- samps_all |
59 | | - slots$summary <- rstan::summary(stanfit)$summary |
60 | | - slots$sampler_params <- rstan::get_sampler_params(stanfit) |
| 75 | + slots$summary <- stan_summary |
| 76 | + slots$sampler_params <- sampler_params |
61 | 77 | slots$nChains <- ncol(stanfit) |
62 | 78 | slots$nIter <- nrow(samps_all) |
63 | 79 | slots$nWarmup <- nWarmup |
64 | 80 | if (!missing(notes)) slots$user_model_info <- notes |
65 | 81 | if (length(mcode) > 0) slots$model_code <- mcode |
66 | | - slots$misc <- list(max_td = max_td, stan_algorithm = stan_algorithm) |
67 | | - do.call("new", slots) |
| 82 | + slots$misc <- list(max_td = max_td, stan_method = stan_method, |
| 83 | + stan_algorithm = stan_algorithm) |
| 84 | + sso <- do.call("new", slots) |
| 85 | + .rename_scalar(sso) |
68 | 86 | } |
| 87 | + |
0 commit comments