Load packages
library(tidyr)
library(rstanarm)
library(loo)
library(ggplot2)
theme_set(bayesplot::theme_default())
library(ggridges)
library(bridgesampling)
This notebook demonstrates a simple model we trust (no model misspecification). In this case, cross-validation (or another model selection approach) is not needed, and we can get better accuracy using the explicit model.
An experiment was performed to estimate the effect of beta-blockers on mortality of cardiac patients (the example is from Gelman et al., 2013, Ch 3). A group of patients were randomly assigned to treatment and control groups:
Data, where grp2 is a dummy variable that captures the difference of the intercepts in the first and the second group.
d_bin2 <- data.frame(N = c(674, 680), y = c(39,22), grp2 = c(0,1))
To analyse whether the treatment is useful, we can use a binomial model for both groups and compute the odds ratio.
fit_bin2 <- stan_glm(y/N ~ grp2, family = binomial(), data = d_bin2,
weights = N, refresh=0)
In general we recommend showing the full posterior of the quantity of interest, which in this case is the odds ratio.
samples_bin2 <- rstan::extract(fit_bin2$stanfit)
theta1 <- plogis(samples_bin2$alpha)
theta2 <- plogis(samples_bin2$alpha + samples_bin2$beta)
oddsratio <- (theta2/(1-theta2))/(theta1/(1-theta1))
ggplot() + geom_histogram(aes(oddsratio), bins = 50, fill = 'grey', color = 'darkgrey') +
labs(y = '') + scale_y_continuous(breaks = NULL)
We can compute the probability that the odds ratio is less than 1:
print(mean(oddsratio<1),2)
[1] 0.99
This posterior distribution of the odds ratio (or some transformation of it) is the simplest and most accurate way to analyse the effectiveness of the treatment. In this case, there is a high probability that the treatment is effective and that the effect is relatively large. Additional observations would help to reduce the uncertainty.
Although we recommend showing the full posterior, the probability that oddsratio < 1 can be a useful summary. The simulation experiment binom_odds_comparison.R
runs 100 simulations with data simulated with varying true odds ratios (0.1, …, 1.0) and computes for each run the probability that oddsratio < 1. A minimal sketch of one replicate is shown below; the following figures then show the variation in the results.
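One replicate could look roughly like this (an illustrative sketch; the function name and the default control-group probability are assumptions, not taken from the actual script):
# One illustrative simulation replicate: draw two-group binomial data with
# a known true odds ratio, fit the model, and compute Pr(oddsratio < 1)
sim_prob_or_lt1 <- function(true_or, N1 = 674, N2 = 680, p1 = 0.06) {
  # convert the control probability and the odds ratio to a treatment probability
  odds2 <- true_or * p1 / (1 - p1)
  p2 <- odds2 / (1 + odds2)
  d <- data.frame(N = c(N1, N2),
                  y = c(rbinom(1, N1, p1), rbinom(1, N2, p2)),
                  grp2 = c(0, 1))
  fit <- stan_glm(y/N ~ grp2, family = binomial(), data = d,
                  weights = N, refresh = 0)
  draws <- as.matrix(fit)
  theta1 <- plogis(draws[, "(Intercept)"])
  theta2 <- plogis(draws[, "(Intercept)"] + draws[, "grp2"])
  or <- (theta2 / (1 - theta2)) / (theta1 / (1 - theta1))
  mean(or < 1)
}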
Variation in probability that oddsratio<1 when true oddsratio is varied.
load(file="binom_test_densities.RData")
ggplot(betaprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We see that for small treatment effects, just by chance, we can observe data that have varying amounts of information about the latent treatment effect.
Sometimes it is better to focus on the observable space (we can’t observe \(\theta\) or the odds ratio directly, but we can observe \(y\)). For example, in the case of many collinear covariates, it can be difficult to interpret the posterior directly in the way we can in this simple example. In such cases, we may investigate the difference in the predictive performance.
In leave-one-out cross-validation, the model is fitted \(n\) times, each time leaving out one observation, which is then used to evaluate the predictive performance. This corresponds to using the already seen observations as pseudo Monte Carlo samples from the future data distribution, with the leave-one-out trick used to avoid double use of data. With the often used log score we get \[\mathrm{LOO} = \frac{1}{n} \sum_{i=1}^n \log {p(y_i|x_i,D_{-i},M_k)}.\]
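As an illustration of this definition only (this is not how the loo package computes it; the PSIS-LOO approximation used below avoids the refits, which would be expensive for \(n=1354\)), a brute-force version would refit the model \(n\) times. A minimal sketch, assuming a binomial stan_glm model; the helper name is illustrative:
# Brute-force LOO for illustration only: refit the model n times, each
# time leaving out observation i and evaluating its log predictive density
brute_force_loo <- function(formula, data) {
  n <- nrow(data)
  lpd <- numeric(n)
  for (i in seq_len(n)) {
    fit_i <- stan_glm(formula, family = binomial(), data = data[-i, ],
                      refresh = 0)
    # log of the posterior-draw-averaged predictive density for observation i
    ll_i <- log_lik(fit_i, newdata = data[i, , drop = FALSE])
    lpd[i] <- log(mean(exp(ll_i)))
  }
  mean(lpd)  # matches the 1/n convention in the formula above
}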
Basic cross-validation makes only the assumption that the future data come from the same distribution as the observed data (weighted cross-validation can be used to handle moderate data shifts), but it doesn’t make any model assumptions about that distribution. This is useful when we don’t trust any model (the models might include good enough models, but we just don’t know if that is the case).
Next we demonstrate one of the weaknesses of cross-validation (same holds for WAIC etc.).
To use leave-one-out where “one” refers to an individual patient, we need to change the model formulation a bit. In the above model formulation, the individual observations have been aggregated into group observations, and running loo(fit_bin2)
would leave out one whole group at a time. With more groups this could be what we want, but with just two groups it is unlikely. Thus, in the following we switch to a Bernoulli model with each individual as its own observation.
d_bin2b <- data.frame(y = c(rep(1,39), rep(0,674-39), rep(1,22), rep(0,680-22)), grp2 = c(rep(0, 674), rep(1, 680)))
fit_bin2b <- stan_glm(y ~ grp2, family = binomial(), data = d_bin2b, seed=180202538, refresh=0)
We also fit a “null” model which doesn’t use the group variable and thus has a common parameter for both groups.
fit_bin2bnull <- stan_glm(y ~ 1, family = binomial(), data = d_bin2b, seed=180202538, refresh=0)
We can then use cross-validation to compare whether adding the treatment variable improves predictive performance. We use fast Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO; Vehtari, Gelman and Gabry, 2017).
(loo_bin2 <- loo(fit_bin2b))
Computed from 4000 by 1354 log-likelihood matrix
Estimate SE
elpd_loo -248.1 23.3
p_loo 2.0 0.2
looic 496.1 46.6
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
(loo_bin2null <- loo(fit_bin2bnull))
Computed from 4000 by 1354 log-likelihood matrix
Estimate SE
elpd_loo -249.7 23.4
p_loo 1.0 0.1
looic 499.4 46.7
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
All Pareto \(k<0.5\) and we can trust PSIS-LOO computation (Vehtari, Gelman and Gabry, 2017; Vehtari et al., 2022).
We make a pairwise comparison.
loo_compare(loo_bin2null, loo_bin2)
elpd_diff se_diff
fit_bin2b 0.0 0.0
fit_bin2bnull -1.6 2.3
elpd_diff
is small compared to se_diff
, and thus cross-validation is uncertain whether estimating the treatment effect improves the predictive performance. To put this in perspective, we have \(N_1=674\) and \(N_2=680\), with 5.8% and 3.2% deaths, which is only weakly informative for cross-validation.
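As a rough heuristic only (the normal approximation for the elpd difference can be optimistic when the difference or the number of observations is small), we could translate the comparison above into an approximate probability that the treatment model has better predictive performance:
# Normal approximation using elpd_diff = -1.6 and se_diff = 2.3 from above;
# this gives roughly 0.76, i.e. far from decisive
pnorm(1.6 / 2.3)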
The simulation experiment binom_odds_comparison.R
runs 100 simulations with data simulated with varying true odds ratios (0.1, …, 1.0) and computes the LOO comparison for each run.
Variation in LOO comparison when true oddsratio is varied.
ggplot(looprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We see that using the posterior distribution from the model is more efficient for detecting the effect, but cross-validation will eventually detect it too. The difference comes from the fact that cross-validation doesn’t trust the model: it compares the model predictions to the “future data” using only a very weak assumption about the future, which leads to a higher variance of the estimates. The weak assumption about the future is also cross-validation’s strength, as we’ll see in another notebook.
We can also compute predictive performance estimates using a stronger assumption about the future. A reference predictive estimate with the log score can be computed as \[ \mathrm{elpd}_{\mathrm{ref}} = \int p(\tilde{y}|D,M_*) \log p(\tilde{y}|D,M_k) d\tilde{y}, \] where \(M_*\) is a reference model we trust. Using a reference model to assess the other models corresponds to the \(M\)-completed case (Vehtari and Ojanen, 2012), where the true model is replaced with a model we trust to be close enough to the true model. The reference model approach has smaller variance than cross-validation, but it is biased towards the reference model. This means the reference model should be carefully checked not to be in conflict with the observed data, and that the reference model approach provides the best predictive performance estimate for the reference model itself. Here we illustrate the reference model approach so that each \(p(\tilde{y}|D,M_k)\) is the usual posterior predictive distribution. Even better would be to use the projection approach, which is demonstrated in other notebooks. See more about the decision theoretical justification of the reference and projection approaches in Section 3.3 of the review by Vehtari and Ojanen (2012), and experimental results by Piironen and Vehtari (2017).
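For the Bernoulli models here, the integral reduces to a sum over \(\tilde{y} \in \{0,1\}\), summed over the observed covariate values. A minimal sketch (the function name is illustrative), using fit_bin2b as the reference model and evaluating the null model:
# Reference predictive estimate for Bernoulli models: expectation of
# log p(y~|D,M_k) under the reference model's posterior predictive
elpd_ref <- function(fit_ref, fit_k, data) {
  # posterior predictive probabilities Pr(y~ = 1) under each model
  p_ref <- colMeans(posterior_epred(fit_ref, newdata = data))
  p_k   <- colMeans(posterior_epred(fit_k, newdata = data))
  # exact expectation over y~ in {0,1} under the reference model
  sum(p_ref * log(p_k) + (1 - p_ref) * log(1 - p_k))
}
elpd_ref(fit_bin2b, fit_bin2bnull, d_bin2b)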
The next figure shows the results from the same simulation study using a reference predictive approach with the fit_bin2
model used as the reference.
ggplot(refprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We can see better accuracy than for cross-validation. We also see, especially when there is no treatment effect, that the reference model approach favors the reference model itself.
A similar and even bigger improvement in model selection performance is observed with projection predictive variable selection (Piironen and Vehtari, 2017; Piironen, Paasiniemi and Vehtari, 2020; McLatchie et al., 2023) implemented in the projpred
package.
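With only one candidate covariate this example is trivial, but as an illustrative sketch of the projpred workflow (not run in this notebook):
# Illustrative projpred workflow: cross-validated search over submodels
# projected from the reference model fit_bin2b
library(projpred)
vs <- cv_varsel(fit_bin2b)
plot(vs, stats = "elpd")
suggest_size(vs)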
As a comparison, we include a marginal likelihood based approach to compute the posterior probabilities of the null model (treatment effect is zero) and the model with an unknown treatment effect. As the data and the models are very simple, we may assume that the models are well specified. Marginal likelihoods and the relative posterior probabilities can be sensitive to the prior chosen for the bigger model. Here we simply use the same rstanarm
default priors as in the examples above. Marginal likelihoods are computed using the default bridge sampling approach implemented in the bridgesampling
package.
# rerun models with diagnostic file required by bridge_sampler
fit_bin2 <- stan_glm(y/N ~ grp2, family = binomial(), data = d_bin2,
weights = N, refresh=0,
diagnostic_file = file.path(tempdir(), "df.csv"))
(ml_bin2 <- bridge_sampler(fit_bin2, silent=TRUE))
Bridge sampling estimate of the log marginal likelihood: -11.47109
Estimate obtained in 5 iteration(s) via method "normal".
fit_bin2null <- stan_glm(y/N ~ 1, family = binomial(), data = d_bin2,
weights = N, refresh=0,
diagnostic_file = file.path(tempdir(), "df.csv"))
(ml_bin2null <- bridge_sampler(fit_bin2null, silent=TRUE))
Bridge sampling estimate of the log marginal likelihood: -11.46144
Estimate obtained in 4 iteration(s) via method "normal".
print(post_prob(ml_bin2, ml_bin2null), digits=2)
ml_bin2 ml_bin2null
0.5 0.5
The posterior probabilities computed from the marginal likelihoods are indecisive.
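The corresponding Bayes factor can be obtained with bf() from the bridgesampling package:
# Bayes factor of the treatment-effect model over the null model
bf(ml_bin2, ml_bin2null)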
We repeat the simulation with the marginal likelihood approach.
ggplot(bfprobs_densities, aes(x = values, y = ind, height = scaled)) +
geom_density_ridges(stat = "identity", scale=0.6)
We can see that the marginal likelihood based approach favors the null model more strongly for smaller treatment effects and requires a bigger effect than the other approaches to stop favoring the null model, but given a big enough effect it is more decisive in favor of the non-null model than cross-validation.
Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A. and Rubin, D. B. (2013) Bayesian data analysis, third edition. CRC Press.
McLatchie, Y., Rögnvaldsson, S., Weber, F. and Vehtari, A. (2023) ‘Robust and efficient projection predictive inference’, arXiv preprint arXiv:2306.15581.
Piironen, J., Paasiniemi, M. and Vehtari, A. (2020) ‘Projective inference in high-dimensional problems: Prediction and feature selection’, Electronic Journal of Statistics, 14(1), pp. 2155–2197.
Piironen, J. and Vehtari, A. (2017) ‘Comparison of Bayesian predictive methods for model selection’, Statistics and Computing, 27(3), pp. 711–735. doi: 10.1007/s11222-016-9649-y.
Vehtari, A., Gelman, A. and Gabry, J. (2017) ‘Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC’, Statistics and Computing, 27(5), pp. 1413–1432. doi: 10.1007/s11222-016-9696-4.
Vehtari, A. and Ojanen, J. (2012) ‘A survey of Bayesian predictive methods for model assessment, selection and comparison’, Statistics Surveys, 6, pp. 142–228. doi: 10.1214/12-SS102.
Vehtari, A., Simpson, D., Gelman, A., Yao, Y. and Gabry, J. (2022) ‘Pareto smoothed importance sampling’, arXiv preprint arXiv:1507.02646. Available at: https://arxiv.org/abs/1507.02646v6.
sessionInfo()
R version 4.2.2 Patched (2022-11-10 r83330)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=fi_FI.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] bridgesampling_1.1-2 ggridges_0.5.4 ggplot2_3.4.4
[4] loo_2.6.0 rstanarm_2.26.1 Rcpp_1.0.11
[7] tidyr_1.3.0
loaded via a namespace (and not attached):
[1] nlme_3.1-162 matrixStats_1.0.0 xts_0.13.1
[4] threejs_0.3.3 rstan_2.26.23 tensorA_0.36.2
[7] backports_1.4.1 tools_4.2.2 bslib_0.5.1
[10] utf8_1.2.4 R6_2.5.1 DT_0.29
[13] colorspace_2.1-0 withr_2.5.1 Brobdingnag_1.2-9
[16] tidyselect_1.2.0 gridExtra_2.3 prettyunits_1.1.1
[19] processx_3.8.2 compiler_4.2.2 cli_3.6.1
[22] shinyjs_2.1.0 labeling_0.4.3 colourpicker_1.3.0
[25] posterior_1.4.1 sass_0.4.7 checkmate_2.3.0
[28] scales_1.2.1 dygraphs_1.1.1.6 mvtnorm_1.2-3
[31] callr_3.7.3 QuickJSR_1.0.5 stringr_1.5.0
[34] digest_0.6.33 StanHeaders_2.26.28 minqa_1.2.5
[37] rmarkdown_2.24 base64enc_0.1-3 pkgconfig_2.0.3
[40] htmltools_0.5.6 lme4_1.1-34 highr_0.10
[43] fastmap_1.1.1 htmlwidgets_1.6.2 rlang_1.1.1
[46] shiny_1.7.5 farver_2.1.1 jquerylib_0.1.4
[49] generics_0.1.3 zoo_1.8-12 jsonlite_1.8.7
[52] crosstalk_1.2.0 gtools_3.9.4 distributional_0.3.2
[55] dplyr_1.1.3 inline_0.3.19 magrittr_2.0.3
[58] bayesplot_1.10.0 Matrix_1.5-1 munsell_0.5.0
[61] fansi_1.0.5 abind_1.4-5 lifecycle_1.0.3
[64] stringi_1.7.12 yaml_2.3.7 MASS_7.3-58.2
[67] pkgbuild_1.4.2 plyr_1.8.8 grid_4.2.2
[70] parallel_4.2.2 promises_1.2.1 crayon_1.5.2
[73] miniUI_0.1.1.1 lattice_0.20-45 splines_4.2.2
[76] knitr_1.43 ps_1.7.5 pillar_1.9.0
[79] igraph_1.5.1 boot_1.3-28 markdown_1.8
[82] shinystan_2.6.0 reshape2_1.4.4 codetools_0.2-19
[85] stats4_4.2.2 rstantools_2.3.1.1 glue_1.6.2
[88] evaluate_0.21 RcppParallel_5.1.7 vctrs_0.6.4
[91] nloptr_2.0.3 httpuv_1.6.11 gtable_0.3.4
[94] purrr_1.0.2 cachem_1.0.8 xfun_0.40
[97] mime_0.12 xtable_1.8-4 coda_0.19-4
[100] later_1.3.1 survival_3.4-0 tibble_3.2.1
[103] shinythemes_1.2.0 ellipsis_0.3.2