Load packages
library(rstanarm)
library(loo)
library(ggplot2)
library(bayesplot)
theme_set(bayesplot::theme_default())
library(projpred)
SEED=170701694
This notebook was inspired by Eric Novik’s slides “Deconstructing Stan Manual Part 1: Linear”. The idea is to demonstrate how easy it is to do good variable selection with rstanarm, loo, and projpred.
In this notebook we illustrate Bayesian inference for model selection, including PSIS-LOO (Vehtari, Gelman and Gabry, 2017) and the projection predictive approach (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020; McLatchie et al., 2023), which makes decision-theoretically justified inference after model selection.
We use the wine quality data set from the UCI Machine Learning Repository.
d <- read.delim("winequality-red.csv", sep = ";")
dim(d)
[1] 1599 12
Remove duplicated observations.
d <- d[!duplicated(d), ] # remove the duplicates
(p <- ncol(d))
[1] 12
(n <- nrow(d))
[1] 1359
names(d)
[1] "fixed.acidity" "volatile.acidity" "citric.acid"
[4] "residual.sugar" "chlorides" "free.sulfur.dioxide"
[7] "total.sulfur.dioxide" "density" "pH"
[10] "sulphates" "alcohol" "quality"
prednames <- names(d)[1:(p-1)]
We scale the covariates so that, when looking at the marginal posteriors for the effects, they are on the same scale.
ds <- scale(d)
winequality_red <- as.data.frame(ds)
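As a quick sanity check, the scaled columns should now have mean approximately 0 and standard deviation 1:
round(colMeans(winequality_red), 2)
round(apply(winequality_red, 2, sd), 2)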
The rstanarm package provides stan_glm, which accepts the same arguments as glm but performs full Bayesian inference using Stan (mc-stan.org). By default, a weakly informative Gaussian prior is used for the weights.
model_formula <- formula(paste("quality ~", paste(prednames, collapse = " + ")))
fitg <- stan_glm(model_formula, data = winequality_red, QR=TRUE,
seed=SEED, refresh=0)
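We can check the priors that were actually used (including rstanarm’s automatic scaling) with prior_summary:
# show the priors used for the intercept, coefficients, and sigma
prior_summary(fitg)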
Let’s look at the summary:
summary(fitg)
Model Info:
function: stan_glm
family: gaussian [identity]
formula: quality ~ fixed.acidity + volatile.acidity + citric.acid + residual.sugar +
chlorides + free.sulfur.dioxide + total.sulfur.dioxide +
density + pH + sulphates + alcohol
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 1359
predictors: 12
Estimates:
mean sd 10% 50% 90%
(Intercept) 0.0 0.0 0.0 0.0 0.0
fixed.acidity 0.0 0.1 -0.1 0.0 0.1
volatile.acidity -0.2 0.0 -0.3 -0.2 -0.2
citric.acid 0.0 0.0 -0.1 0.0 0.0
residual.sugar 0.0 0.0 0.0 0.0 0.0
chlorides -0.1 0.0 -0.1 -0.1 -0.1
free.sulfur.dioxide 0.0 0.0 0.0 0.0 0.1
total.sulfur.dioxide -0.1 0.0 -0.2 -0.1 -0.1
density 0.0 0.1 -0.1 0.0 0.0
pH -0.1 0.0 -0.1 -0.1 0.0
sulphates 0.2 0.0 0.2 0.2 0.2
alcohol 0.4 0.0 0.3 0.4 0.4
sigma 0.8 0.0 0.8 0.8 0.8
Fit Diagnostics:
mean sd 10% 50% 90%
mean_PPD 0.0 0.0 0.0 0.0 0.0
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
mcse Rhat n_eff
(Intercept) 0.0 1.0 5009
fixed.acidity 0.0 1.0 5523
volatile.acidity 0.0 1.0 5615
citric.acid 0.0 1.0 5884
residual.sugar 0.0 1.0 4943
chlorides 0.0 1.0 5330
free.sulfur.dioxide 0.0 1.0 5237
total.sulfur.dioxide 0.0 1.0 5496
density 0.0 1.0 4811
pH 0.0 1.0 5149
sulphates 0.0 1.0 5378
alcohol 0.0 1.0 5213
sigma 0.0 1.0 4864
mean_PPD 0.0 1.0 4282
log-posterior 0.1 1.0 1578
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
We didn’t get any divergences, the Rhat values are all less than 1.1, and the n_eff values are large enough to be useful (see, e.g., the RStan workflow).
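For a visual check we could also plot these diagnostics with bayesplot, using the rhat and neff_ratio extractors for stanreg fits (a quick optional check):
# dot plots of Rhat and relative effective sample sizes
mcmc_rhat(rhat(fitg))
mcmc_neff(neff_ratio(fitg))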
mcmc_areas(as.matrix(fitg), pars=prednames, prob_outer = .95)
Several 95% posterior intervals do not overlap 0, so there may be something useful here.
In case of collinear variables it is possible that the marginal posteriors overlap 0, yet the covariates can still be useful for prediction. With many variables it is difficult to analyse the joint posterior to see which variables are jointly relevant. We can easily test whether any of the covariates are useful by using cross-validation to compare against a null model:
fitg0 <- stan_glm(quality ~ 1, data = winequality_red, seed=SEED, refresh=0)
We use fast Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO; Vehtari, Gelman and Gabry, 2017).
(loog <- loo(fitg))
Computed from 4000 by 1359 log-likelihood matrix
Estimate SE
elpd_loo -1635.7 30.6
p_loo 16.6 1.6
looic 3271.4 61.2
------
Monte Carlo SE of elpd_loo is 0.1.
Pareto k diagnostic values:
Count Pct. Min. n_eff
(-Inf, 0.5] (good) 1358 99.9% 1600
(0.5, 0.7] (ok) 1 0.1% 1270
(0.7, 1] (bad) 0 0.0% <NA>
(1, Inf) (very bad) 0 0.0% <NA>
All Pareto k estimates are ok (k < 0.7).
See help('pareto-k-diagnostic') for details.
(loog0 <- loo(fitg0))
Computed from 4000 by 1359 log-likelihood matrix
Estimate SE
elpd_loo -1929.9 28.3
p_loo 2.2 0.2
looic 3859.9 56.5
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog0, loog)
elpd_diff se_diff
fitg 0.0 0.0
fitg0 -294.3 22.9
Based on cross-validation, the covariates together have high predictive power. If we need just the predictions we can stop here, but if we want to learn more about the relevance of the covariates we can continue with variable selection.
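As a side note, if we needed just the predictions from the full model, they could be drawn with posterior_predict (a minimal sketch):
# one row per posterior draw, one column per observation (on the scaled outcome)
yrep <- posterior_predict(fitg)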
We perform projection predictive variable selection (Piironen and Vehtari, 2017a; Piironen, Paasiniemi and Vehtari, 2020) using the projpred package. A fast PSIS-LOO (Vehtari, Gelman and Gabry, 2017) is used to choose the model size. As the number of observations is large compared to the number of covariates, we estimate the performance using LOO-CV only along the search path (validate_search=FALSE), as we may assume that the overfitting in the search is negligible (see more about this in McLatchie et al. (2023)).
fitg_cv <- cv_varsel(fitg, method='forward', cv_method='loo', validate_search=FALSE)
We can now look at the estimated predictive performance of smaller models compared to the full model.
plot(fitg_cv, stats = c('elpd', 'rmse'), text_angle = 45)
Three or four variables seem to be needed to get the same performance as the full model. As the estimated predictive performance does not rise much above the reference model performance, we know that using the option validate_search=FALSE was safe (see more in McLatchie et al. (2023)).
We can get a LOO-CV based recommendation for the model size to choose.
(nsel <- suggest_size(fitg_cv, alpha=0.1))
[1] 4
(vsel <- solution_terms(fitg_cv)[1:nsel])
[1] "alcohol" "volatile.acidity" "sulphates" "chlorides"
projpred recommends using four variables: alcohol, volatile.acidity, sulphates, and chlorides.
Next we form the projected posterior for the chosen model. This projected model can be used in the future to make predictions by using only the selected variables.
projg <- project(fitg_cv, nv = nsel, ns = 4000)
round(colMeans(as.matrix(projg)), 1)
(Intercept) alcohol volatile.acidity
0.0 0.4 -0.3
sulphates chlorides total.sulfur.dioxide
0.2 -0.1 -0.1
sigma
0.8
round(posterior_interval(as.matrix(projg)), 1)
5% 95%
(Intercept) 0.0 0.0
alcohol 0.3 0.4
volatile.acidity -0.3 -0.2
sulphates 0.2 0.2
chlorides -0.1 -0.1
total.sulfur.dioxide -0.1 0.0
sigma 0.8 0.8
The marginals of the projected posterior look like this.
mcmc_areas(as.matrix(projg), pars = vsel)
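If we wanted to use the projected submodel for prediction, projpred provides proj_linpred and proj_predict (a minimal sketch, not evaluated here):
# expected values from the projected submodel, averaged over the projected draws
mu_proj <- proj_linpred(projg, newdata = winequality_red, integrated = TRUE)
# draws from the predictive distribution of the projected submodel
ypred_proj <- proj_predict(projg, newdata = winequality_red)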
We also test the regularized horseshoe prior (Piironen and Vehtari, 2017b), which has more prior mass near 0.
p0 <- 5 # prior guess for the number of relevant variables
tau0 <- p0/(p-p0) * 1/sqrt(n) # scale for the global shrinkage parameter
hs_prior <- hs(df=1, global_df=1, global_scale=tau0)
fitrhs <- stan_glm(model_formula, data = winequality_red, prior=hs_prior,
seed=SEED, refresh=0)
mcmc_areas(as.matrix(fitrhs), pars=prednames, prob_outer = .95)
Many of the variables are shrunk more towards 0, but based on these marginals alone it is still not as easy to select the most useful variables as it is with projpred.
The posteriors with normal and regularized horseshoe priors are clearly different, but does this have an effect on the predictions? In the case of collinearity the prior may have a strong effect on the posterior, but a weak effect on the posterior predictions. We can use loo to compare:
(loorhs <- loo(fitrhs))
Computed from 4000 by 1359 log-likelihood matrix
Estimate SE
elpd_loo -1634.2 30.5
p_loo 13.8 1.3
looic 3268.3 61.0
------
Monte Carlo SE of elpd_loo is 0.1.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog, loorhs)
elpd_diff se_diff
fitrhs 0.0 0.0
fitg -1.5 1.4
There is no difference in predictive performance, and thus we don’t need to repeat the projpred variable selection for the model with the regularized horseshoe prior.
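If we nevertheless wanted to repeat the selection with the regularized horseshoe prior, the same call would work on the horseshoe fit (a sketch, not evaluated here):
fitrhs_cv <- cv_varsel(fitrhs, method='forward', cv_method='loo', validate_search=FALSE)
plot(fitrhs_cv, stats = c('elpd', 'rmse'), text_angle = 45)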
McLatchie, Y., Rögnvaldsson, S., Weber, F. and Vehtari, A. (2023) ‘Robust and efficient projection predictive inference’, arXiv preprint arXiv:2306.15581.
Piironen, J., Paasiniemi, M. and Vehtari, A. (2020) ‘Projective inference in high-dimensional problems: Prediction and feature selection’, Electronic Journal of Statistics, 14(1), pp. 2155–2197.
Piironen, J. and Vehtari, A. (2017a) ‘Comparison of Bayesian predictive methods for model selection’, Statistics and Computing, 27(3), pp. 711–735. doi: 10.1007/s11222-016-9649-y.
Piironen, J. and Vehtari, A. (2017b) ‘Sparsity information and regularization in the horseshoe and other shrinkage priors’, Electronic journal of Statistics, 11(2), pp. 5018–5051. doi: 10.1214/17-EJS1337SI.
Vehtari, A., Gelman, A. and Gabry, J. (2017) ‘Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC’, Statistics and Computing, 27(5), pp. 1413–1432. doi: 10.1007/s11222-016-9696-4.
sessionInfo()
R version 4.2.2 Patched (2022-11-10 r83330)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=fi_FI.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] cmdstanr_0.6.1.9000 brms_2.20.1 projpred_2.7.0
[4] bayesplot_1.10.0 GGally_2.1.2 bridgesampling_1.1-2
[7] ggridges_0.5.4 ggplot2_3.4.4 loo_2.6.0
[10] rstanarm_2.26.1 Rcpp_1.0.11 tidyr_1.3.0
loaded via a namespace (and not attached):
[1] minqa_1.2.5 colorspace_2.1-0 ellipsis_0.3.2
[4] markdown_1.8 QuickJSR_1.0.5 base64enc_0.1-3
[7] farver_2.1.1 rstan_2.26.23 DT_0.29
[10] fansi_1.0.5 mvtnorm_1.2-3 codetools_0.2-19
[13] splines_4.2.2 cachem_1.0.8 knitr_1.43
[16] shinythemes_1.2.0 jsonlite_1.8.7 nloptr_2.0.3
[19] shiny_1.7.5 compiler_4.2.2 backports_1.4.1
[22] Matrix_1.5-1 fastmap_1.1.1 cli_3.6.1
[25] later_1.3.1 htmltools_0.5.6 prettyunits_1.1.1
[28] tools_4.2.2 igraph_1.5.1 coda_0.19-4
[31] gtable_0.3.4 glue_1.6.2 reshape2_1.4.4
[34] dplyr_1.1.3 posterior_1.4.1 jquerylib_0.1.4
[37] vctrs_0.6.4 nlme_3.1-162 crosstalk_1.2.0
[40] tensorA_0.36.2 xfun_0.40 stringr_1.5.0
[43] ps_1.7.5 lme4_1.1-34 mime_0.12
[46] miniUI_0.1.1.1 lifecycle_1.0.3 gtools_3.9.4
[49] MASS_7.3-58.2 zoo_1.8-12 scales_1.2.1
[52] colourpicker_1.3.0 hms_1.1.3 promises_1.2.1
[55] Brobdingnag_1.2-9 parallel_4.2.2 inline_0.3.19
[58] shinystan_2.6.0 RColorBrewer_1.1-3 yaml_2.3.7
[61] gridExtra_2.3 StanHeaders_2.26.28 sass_0.4.7
[64] reshape_0.8.9 stringi_1.7.12 highr_0.10
[67] dygraphs_1.1.1.6 checkmate_2.3.0 boot_1.3-28
[70] pkgbuild_1.4.2 rlang_1.1.1 pkgconfig_2.0.3
[73] matrixStats_1.0.0 distributional_0.3.2 evaluate_0.21
[76] lattice_0.20-45 purrr_1.0.2 rstantools_2.3.1.1
[79] htmlwidgets_1.6.2 labeling_0.4.3 processx_3.8.2
[82] tidyselect_1.2.0 plyr_1.8.8 magrittr_2.0.3
[85] R6_2.5.1 generics_0.1.3 pillar_1.9.0
[88] withr_2.5.1 xts_0.13.1 survival_3.4-0
[91] abind_1.4-5 tibble_3.2.1 crayon_1.5.2
[94] utf8_1.2.4 rmarkdown_2.24 progress_1.2.2
[97] grid_4.2.2 data.table_1.14.8 callr_3.7.3
[100] threejs_0.3.3 digest_0.6.33 xtable_1.8-4
[103] httpuv_1.6.11 RcppParallel_5.1.7 stats4_4.2.2
[106] munsell_0.5.0 bslib_0.5.1 shinyjs_2.1.0