Load packages
library(rstanarm)
options(mc.cores = 1)
library(loo)
library(tibble)  # provides tibble(), used below to build the simulated data
library(ggplot2)
library(GGally)
library(bayesplot)
theme_set(bayesplot::theme_default())
library(projpred)
SEED <- 87
This notebook was inspired by Andrew Tyre’s blog post “Does model averaging make sense?”. Tyre discusses problems in current statistical practices in ecology, focusing on multicollinearity, model averaging and measuring the relative importance of variables. Tyre’s post comments on the paper “Model averaging and muddled multimodel inferences”. In his blog post he uses maximum likelihood and AICc. Here we provide a Bayesian approach for handling multicollinearity, model averaging and measuring the relative importance of variables using the packages rstanarm, bayesplot, loo and projpred. We demonstrate the benefits of Bayesian posterior analysis (Gelman et al., 2013) and the projection predictive approach (Piironen and Vehtari, 2017; Piironen, Paasiniemi and Vehtari, 2020; McLatchie et al., 2023).
We generate the same simulated data that were used previously to illustrate multicollinearity problems.
# all this data generation is from Cade 2015
# doesn't matter what this is -- if you use a different number your results will be different from mine.
set.seed(SEED)
data <- tibble(
  pos.tot   = runif(200, min = 0.8, max = 1.0),
  urban.tot = pmin(runif(200, min = 0.0, max = 0.02), 1.0 - pos.tot),
  neg.tot   = (1.0 - pmin(pos.tot + urban.tot, 1)),
  x1 = pmax(pos.tot - runif(200, min = 0.05, max = 0.30), 0),
  x3 = pmax(neg.tot - runif(200, min = 0.0, max = 0.10), 0),
  x2 = pmax(pos.tot - x1 - x3/2, 0),
  x4 = pmax(1 - x1 - x2 - x3 - urban.tot, 0))
# true model and 200 Poisson observations
mean.y <- exp(-5.8 + 6.3*data$x1 + 15.2*data$x2)
data$y <- rpois(200,mean.y)
ggpairs(data,diag=list(continuous="barDiag"))
Tyre: “So there is a near perfect negative correlation between the things sage grouse like and the things they don’t like, although it gets less bad when considering the individual covariates.”
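The collinearity can also be checked numerically. The base-R sketch below regenerates the simulated data with a hypothetical helper, simulate_cade(), written only so that the block runs on its own (its random draws are made in the same order as in the tibble() call above). It shows that pos.tot and neg.tot are almost perfectly negatively correlated and that, by construction, x1 + x2 + x3 + x4 + urban.tot equals 1, so the four covariates are nearly linearly dependent once the model has an intercept. The variance inflation factors are computed by hand as 1 / (1 - R^2) from regressing each covariate on the other three.

```r
# Hypothetical helper: regenerates the notebook's simulated data in base R
simulate_cade <- function(n = 200, seed = 87) {
  set.seed(seed)
  pos.tot   <- runif(n, min = 0.8, max = 1.0)
  urban.tot <- pmin(runif(n, min = 0.0, max = 0.02), 1.0 - pos.tot)
  neg.tot   <- 1.0 - pmin(pos.tot + urban.tot, 1)
  x1 <- pmax(pos.tot - runif(n, min = 0.05, max = 0.30), 0)
  x3 <- pmax(neg.tot - runif(n, min = 0.0, max = 0.10), 0)
  x2 <- pmax(pos.tot - x1 - x3 / 2, 0)
  x4 <- pmax(1 - x1 - x2 - x3 - urban.tot, 0)
  y  <- rpois(n, exp(-5.8 + 6.3 * x1 + 15.2 * x2))
  data.frame(pos.tot, urban.tot, neg.tot, x1, x2, x3, x4, y)
}

d <- simulate_cade()

# Near-perfect negative correlation between the aggregate covariates
r_pos_neg <- cor(d$pos.tot, d$neg.tot)

# The covariates plus urban.tot sum to one, so with an intercept in the model
# only the small variation in urban.tot keeps them from being exactly collinear
max_dev <- max(abs(d$x1 + d$x2 + d$x3 + d$x4 + d$urban.tot - 1))

# Variance inflation factors, VIF_j = 1 / (1 - R^2_j)
xs  <- c("x1", "x2", "x3", "x4")
vif <- sapply(xs, function(v) {
  fit <- lm(reformulate(setdiff(xs, v), response = v), data = d)
  1 / (1 - summary(fit)$r.squared)
})

print(round(r_pos_neg, 3))
print(max_dev)
print(round(vif, 1))
```

A common rule of thumb treats VIFs above about 10 as a sign of problematic collinearity, even when the pairwise correlations of the individual covariates look less alarming.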
From this point onwards we switch to a Bayesian approach. The rstanarm package provides the stan_glm function, which accepts the same arguments as glm but performs full Bayesian inference using Stan (mc-stan.org). By default, a weakly informative Gaussian prior is used for the weights (regression coefficients).
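To get a feel for what “weakly informative” means here, the sketch below assumes the default coefficient prior documented for rstanarm, normal(0, 2.5, autoscale = TRUE), and assumes that for a non-Gaussian family such as Poisson the autoscaling divides the scale 2.5 by the standard deviation of each predictor. This applies to the first fit without QR=TRUE. The helper default_prior_scale() is hypothetical, not an rstanarm function; prior_summary(fitg) reports the values rstanarm actually used.

```r
set.seed(87)
n <- 200
# Covariates generated as in the notebook (the response is not needed here)
pos.tot   <- runif(n, 0.8, 1.0)
urban.tot <- pmin(runif(n, 0.0, 0.02), 1.0 - pos.tot)
neg.tot   <- 1.0 - pmin(pos.tot + urban.tot, 1)
x1 <- pmax(pos.tot - runif(n, 0.05, 0.30), 0)
x3 <- pmax(neg.tot - runif(n, 0.0, 0.10), 0)
x2 <- pmax(pos.tot - x1 - x3 / 2, 0)
x4 <- pmax(1 - x1 - x2 - x3 - urban.tot, 0)
X  <- cbind(x1, x2, x3, x4)

# Hypothetical helper: prior sd implied by normal(0, 2.5, autoscale = TRUE),
# assuming the rule scale = 2.5 / sd(x) for non-Gaussian families
default_prior_scale <- function(X, base_scale = 2.5) {
  base_scale / apply(X, 2, sd)
}

prior_scales <- default_prior_scale(X)
print(round(prior_scales, 1))
```

Because all covariates lie between 0 and 1 and vary over a narrow range, the implied prior scales on the original coefficient scale are wide, which is consistent with calling this prior only weakly informative. Comparing these numbers with prior_summary(fitg) is a quick way to check the assumption.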
fitg <- stan_glm(y ~ x1 + x2 + x3 + x4, data = data, na.action = na.fail,
family=poisson(), seed=SEED, refresh=0)
Let’s look at the summary:
summary(fitg)
Model Info:
function: stan_glm
family: poisson [log]
formula: y ~ x1 + x2 + x3 + x4
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 200
predictors: 5
Estimates:
               mean    sd   10%   50%   90%
(Intercept)     2.1   5.7  -5.0   2.0   9.4
x1             -1.6   5.7  -8.9  -1.5   5.6
x2              6.9   5.7  -0.4   7.0  14.0
x3             -8.5   5.8 -15.9  -8.4  -1.2
x4             -8.4   6.0 -16.0  -8.3  -0.7
Fit Diagnostics:
               mean    sd   10%   50%   90%
mean_PPD        4.3   0.2   4.1   4.3   4.6
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
                 mcse  Rhat n_eff
(Intercept)       0.2   1.0   611
x1                0.2   1.0   612
x2                0.2   1.0   613
x3                0.2   1.0   625
x4                0.2   1.0   629
mean_PPD          0.0   1.0  2962
log-posterior     0.0   1.0  1284
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
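The diagnostics are related: the Monte Carlo standard error of a posterior mean is sd / sqrt(n_eff), e.g. 5.7 / sqrt(611) ≈ 0.23 for x1, consistent with the rounded 0.2 in the table. The base-R sketch below shows that arithmetic and a basic split-Rhat in the form given by Gelman et al. (2013); newer Stan tools use a rank-normalized variant, so treat split_rhat() as a hypothetical illustration rather than the exact computation behind the table.

```r
# Monte Carlo standard error of a posterior mean: sd / sqrt(n_eff)
mcse_mean <- function(post_sd, n_eff) post_sd / sqrt(n_eff)
mcse_x1 <- mcse_mean(5.7, 611)   # sd and n_eff of x1 from the table above

# Split-Rhat: split each chain in half and compare between-half
# and within-half variances
split_rhat <- function(draws) {          # draws: iterations x chains matrix
  n <- floor(nrow(draws) / 2)
  halves <- cbind(draws[1:n, , drop = FALSE],
                  draws[(nrow(draws) - n + 1):nrow(draws), , drop = FALSE])
  B <- n * var(colMeans(halves))         # between-chain variance
  W <- mean(apply(halves, 2, var))       # mean within-chain variance
  var_plus <- (n - 1) / n * W + B / n    # pooled posterior variance estimate
  sqrt(var_plus / W)
}

set.seed(1)
good <- matrix(rnorm(4000), nrow = 1000, ncol = 4)  # 4 well-mixed chains
bad  <- good
bad[, 4] <- bad[, 4] + 3                            # one chain stuck elsewhere

print(round(mcse_x1, 2))
print(round(c(good = split_rhat(good), bad = split_rhat(bad)), 3))
```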
We didn’t get divergences, the Rhat values are less than 1.01 and the n_eff values are large enough to be useful (see, e.g., RStan workflow). However, when we know that the covariates are correlated, we can get even better sampling performance by using the QR decomposition (see the case study The QR Decomposition For Regression Models).
fitg <- stan_glm(y ~ x1 + x2 + x3 + x4, data = data, na.action = na.fail,
family=poisson(), QR=TRUE, seed=SEED, refresh=0)
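The idea behind QR=TRUE can be illustrated without Stan. After centering, the design matrix is factored as X = Q R, and the sampler works with coefficients on the columns of Q, which are uncorrelated even though the columns of X are nearly linearly dependent. The base-R sketch below uses the scaling Q* = Q sqrt(n - 1), R* = R / sqrt(n - 1) from the QR case study mentioned above; I am assuming rstanarm’s scaled QR transformation is similar, so this is an illustration, not its exact code. The coefficient vector beta_demo is made up.

```r
set.seed(87)
n <- 200
# Covariates generated as in the notebook (the response is not needed here)
pos.tot   <- runif(n, 0.8, 1.0)
urban.tot <- pmin(runif(n, 0.0, 0.02), 1.0 - pos.tot)
neg.tot   <- 1.0 - pmin(pos.tot + urban.tot, 1)
x1 <- pmax(pos.tot - runif(n, 0.05, 0.30), 0)
x3 <- pmax(neg.tot - runif(n, 0.0, 0.10), 0)
x2 <- pmax(pos.tot - x1 - x3 / 2, 0)
x4 <- pmax(1 - x1 - x2 - x3 - urban.tot, 0)

X  <- cbind(x1, x2, x3, x4)
Xc <- scale(X, center = TRUE, scale = FALSE)   # centered predictors

qx     <- qr(Xc)
Q_star <- qr.Q(qx) * sqrt(n - 1)
R_star <- qr.R(qx) / sqrt(n - 1)

# Columns of X are correlated; columns of Q* are not
max_cor_X <- max(abs(cor(X)[upper.tri(diag(4))]))
max_cor_Q <- max(abs(cor(Q_star)[upper.tri(diag(4))]))

# beta on the X scale maps to theta = R* beta on the Q* scale, and the
# linear predictor is unchanged: Xc beta = Q* theta
beta_demo <- c(6.3, 15.2, 0, 0)                # made-up values for illustration
theta     <- R_star %*% beta_demo
same_eta  <- all.equal(drop(Xc %*% beta_demo), drop(Q_star %*% theta))
beta_back <- backsolve(R_star, theta)          # map theta back to beta

print(round(c(max_cor_X = max_cor_X, max_cor_Q = max_cor_Q), 3))
```

Since the likelihood is unchanged and only the parameterization differs, draws of theta can be mapped back to the original coefficients with R*, which is why the summary still reports x1 to x4.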
Let’s look at the summary and plot:
summary(fitg)
Model Info:
function: stan_glm
family: poisson [log]
formula: y ~ x1 + x2 + x3 + x4
algorithm: sampling
sample: 4000 (posterior sample size)
priors: see help('prior_summary')
observations: 200
predictors: 5
Estimates:
               mean    sd   10%   50%   90%
(Intercept)     2.5   6.1  -5.4   2.6  10.2
x1             -2.0   6.1  -9.8  -2.1   6.0
x2              6.5   6.1  -1.4   6.4  14.6
x3             -8.9   6.1 -16.5  -8.9  -1.0
x4             -8.7   6.3 -16.9  -8.8  -0.6
Fit Diagnostics:
               mean    sd   10%   50%   90%
mean_PPD        4.3   0.2   4.1   4.3   4.6
The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).
MCMC diagnostics
                 mcse  Rhat n_eff
(Intercept)       0.1   1.0  4112
x1                0.1   1.0  4114
x2                0.1   1.0  4116
x3                0.1   1.0  4259
x4                0.1   1.0  3946
mean_PPD          0.0   1.0  4127
log-posterior     0.0   1.0  1742
For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
Using the QR decomposition greatly improved sampling efficiency (n_eff for the coefficients increased from about 600 to about 4000), and we continue with this model.
mcmc_areas(as.matrix(fitg), prob_outer = .99)
All marginal 95% posterior intervals overlap 0, and it seems we have the same collinearity problem as with the maximum likelihood estimates.
Looking at the pairwise posteriors, we can see high correlations:
mcmc_pairs(as.matrix(fitg), pars = c("x1","x2","x3","x4"))
If we look more carefully at one of the subplots, we see that although the marginal posterior intervals overlap 0, the joint posterior does not.
mcmc_scatter(as.matrix(fitg), pars = c("x1", "x2"))+geom_vline(xintercept=0)+geom_hline(yintercept=0)
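A small simulation shows how this can happen. The base-R sketch below draws from a made-up bivariate normal posterior with means and standard deviations roughly like those of x1 and x2 in the summary above, and an assumed correlation of -0.99 (not estimated from fitg). Both marginal 95% intervals contain 0, while the interval for x1 + x2, a direction in which the joint posterior is narrow, clearly excludes 0. With the real model the same check could be done on the columns of as.matrix(fitg).

```r
set.seed(2)
S   <- 4000
rho <- -0.99                   # assumed posterior correlation (hypothetical)
mu  <- c(x1 = -2.0, x2 = 6.5)  # roughly the posterior means in the QR summary
s   <- 6                       # roughly the posterior sds in the QR summary

# Correlated normal draws built from two independent standard normals
z1 <- rnorm(S)
z2 <- rnorm(S)
b1 <- mu[["x1"]] + s * z1
b2 <- mu[["x2"]] + s * (rho * z1 + sqrt(1 - rho^2) * z2)

ci <- function(v) quantile(v, c(0.025, 0.975))
ci_b1  <- ci(b1)
ci_b2  <- ci(b2)
ci_sum <- ci(b1 + b2)          # narrow direction of the joint posterior
print(round(rbind(x1 = ci_b1, x2 = ci_b2, "x1 + x2" = ci_sum), 2))
```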
Based on the joint distribution, all the variables would be relevant. To make predictions we don’t need to do variable selection; we just integrate over the uncertainty (a kind of continuous model averaging).
With even more variables, some relevant and some irrelevant, it would be difficult to analyse the joint posterior to see which variables are jointly relevant. We can easily test whether any of the covariates are useful by using cross-validation to compare the full model to a null model:
fitg0 <- stan_glm(y ~ 1, data = data, na.action = na.fail,
family=poisson(), seed=SEED, refresh=0)
(loog <- loo(fitg))
Computed from 4000 by 200 log-likelihood matrix
          Estimate    SE
elpd_loo    -383.3  12.1
p_loo          5.3   0.7
looic        766.6  24.2
------
Monte Carlo SE of elpd_loo is 0.0.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
(loog0 <- loo(fitg0))
Computed from 4000 by 200 log-likelihood matrix
          Estimate    SE
elpd_loo    -714.5  43.8
p_loo          4.9   0.8
looic       1428.9  87.5
------
Monte Carlo SE of elpd_loo is 0.1.
All Pareto k estimates are good (k < 0.5).
See help('pareto-k-diagnostic') for details.
loo_compare(loog0, loog)
      elpd_diff se_diff
fitg        0.0     0.0
fitg0    -331.2    42.7
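The comparison is built from the pointwise elpd values (available, for example, as loog$pointwise[, "elpd_loo"]). The difference between two models is the sum of the pointwise differences, and its standard error is sqrt(n) times the standard deviation of those differences; the row for fitg0 is elpd(fitg0) - elpd(fitg) = -714.5 - (-383.3) = -331.2. The base-R sketch below reproduces that arithmetic on three made-up pointwise values with a hypothetical helper, elpd_compare(), and computes the ratio of the reported difference to its standard error.

```r
# Hypothetical helper mirroring the summary reported by loo_compare()
elpd_compare <- function(elpd_a, elpd_b) {
  d <- elpd_a - elpd_b                          # pointwise differences
  c(elpd_diff = sum(d), se_diff = sqrt(length(d)) * sd(d))
}

# Toy pointwise elpd values (made up) for two models and three observations
toy <- elpd_compare(c(-1.0, -2.0, -1.5), c(-2.0, -2.5, -3.0))
print(toy)

# Ratio of the difference to its standard error in the table above
z <- 331.2 / 42.7
print(round(z, 2))
```

A difference of almost eight standard errors leaves little doubt that the covariates improve predictions compared to the null model.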
Based on cross-validation, the covariates together contain substantial information that improves predictions.
We might want to choose some variables 1) because we don’t want to observe all the variables in the future (e.g. due to the measurement cost), or 2) because we want to find the most relevant variables, which we define here as a minimal set of variables that provides predictions similar to those of the full model.
Tyre used AICc to estimate the model performance. In a Bayesian setting we could use Bayesian cross-validation or WAIC, but we don’t recommend them for variable selection, as discussed by Piironen and Vehtari (2017). The reason for not using Bayesian CV or WAIC is that the selection process uses the data twice, and with a large number of variable combinations the selection process overfits and can produce really bad models. Using the usual posterior inference given the selected variables ignores that the selected variables are conditional on the selection process, and simply setting some variables to 0 ignores the uncertainty related to their relevance.
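A toy base-R illustration of why selecting the best model by an estimated performance measure can be optimistic: if many candidate models have the same true predictive performance and we pick the one with the best noisy estimate, the winner’s estimate is biased upwards. This is a simplified illustration of the selection-induced bias discussed by Piironen and Vehtari (2017), not a reproduction of their experiments; the number of models and the noise level are arbitrary assumptions.

```r
set.seed(3)
n_models <- 100   # candidate models, all with the same true performance (0)
n_reps   <- 1000  # repetitions of the whole selection process
noise_sd <- 1     # sd of the noise in each model's estimated performance

# In each repetition, estimate every model's performance and keep the best
selected_est <- replicate(n_reps, max(rnorm(n_models, mean = 0, sd = noise_sd)))

# The selected model looks better than it is: its true performance is 0
print(round(mean(selected_est), 2))
```

The more candidate models are compared, the larger this optimism becomes, which is why the search itself needs to be validated when it can overfit.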
Piironen and Vehtari (2017) also show that a projection predictive approach can be used for model reduction, that is, choosing a smaller model with some coefficients set to 0. The projection predictive approach solves the problem of how to do inference after selection. The solution is to project the full model posterior onto the restricted subspace. See more in Piironen, Paasiniemi and Vehtari (2020) and McLatchie et al. (2023).
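For a single posterior draw, the projection finds the submodel parameters whose predictive distribution is closest, in Kullback-Leibler divergence, to that of the reference model. For a Poisson model with a log link this reduces to fitting the submodel by maximum likelihood with the reference model’s expected counts in place of the observed y. The base-R sketch below shows only that core step; it is not projpred’s implementation, which also clusters or thins draws, runs the search and supports other families. The helper project_poisson() is hypothetical, and a glm() fit of the full model stands in for one posterior draw of fitg.

```r
set.seed(87)
n <- 200
# Data generated as in the notebook
pos.tot   <- runif(n, 0.8, 1.0)
urban.tot <- pmin(runif(n, 0.0, 0.02), 1.0 - pos.tot)
neg.tot   <- 1.0 - pmin(pos.tot + urban.tot, 1)
x1 <- pmax(pos.tot - runif(n, 0.05, 0.30), 0)
x3 <- pmax(neg.tot - runif(n, 0.0, 0.10), 0)
x2 <- pmax(pos.tot - x1 - x3 / 2, 0)
x4 <- pmax(1 - x1 - x2 - x3 - urban.tot, 0)
mu_true <- exp(-5.8 + 6.3 * x1 + 15.2 * x2)
y <- rpois(n, mu_true)
d <- data.frame(y, x1, x2, x3, x4)

# Hypothetical helper: project reference predictions mu_ref onto a submodel.
# Minimising the Poisson KL divergence equals fitting the submodel with mu_ref
# as the response; quasipoisson() gives the same estimates as poisson() but
# does not warn about non-integer responses.
project_poisson <- function(mu_ref, terms, data) {
  data$mu_ref <- mu_ref
  fit <- glm(reformulate(terms, response = "mu_ref"),
             family = quasipoisson(), data = data)
  coef(fit)
}

# 1) Sanity check: projecting the true mean recovers the true coefficients
proj_true <- project_poisson(mu_true, c("x1", "x2"), d)

# 2) A full-model glm() fit as a stand-in for one reference-model draw
full   <- glm(y ~ x1 + x2 + x3 + x4, family = poisson(), data = d)
mu_ref <- fitted(full)
proj_1 <- project_poisson(mu_ref, c("x1", "x2"), d)

print(round(rbind(true_mean = proj_true, full_fit = proj_1), 2))
```

Because the true model lies inside the submodel with x1 and x2, projecting the true expected counts returns the generating coefficients (-5.8, 6.3, 15.2). Repeating the projection for every posterior draw gives the projected posterior.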
We perform the projection predictive variable selection using the previous full model as the reference model. A fast leave-one-out cross-validation approach (Vehtari, Gelman and Gabry, 2017) is used to choose the model size. As the number of observations is large compared to the number of covariates, we estimate the performance using LOO-CV only along the search path (validate_search=FALSE), as we may assume that the overfitting in the search is negligible (see more about this in McLatchie et al. (2023)).
cvvs <- cv_varsel(fitg, method='forward', cv_method='loo', validate_search=FALSE)
We can now look at the estimated predictive performance of smaller models compared to the full model.
plot(cvvs, stats = c('elpd', 'rmse'))
As the estimated predictive performance of the submodels does not go much above the reference model performance, we know that using validate_search=FALSE was safe (see more in McLatchie et al. (2023)).
We also get a LOO-based recommendation for the model size to choose:
(nsel <- suggest_size(cvvs))
[1] 2
(vsel <- solution_terms(cvvs)[1:nsel])
[1] "x2" "x1"
We see that two variables are enough to get the same predictive accuracy as with all four variables.
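As I read the projpred documentation for suggest_size() (defaults stat = "elpd", pct = 0, type = "upper"; see ?projpred::suggest_size for the exact definition), the suggested size is the smallest submodel size for which the upper end of a roughly one-standard-error interval for the elpd difference to the reference model reaches zero. The base-R sketch below applies that rule to made-up differences and standard errors; suggest_size_sketch() is a hypothetical helper, not a projpred function.

```r
# Made-up elpd differences (submodel minus reference) and their SEs
# for submodel sizes 0..4, for illustration only
sizes     <- 0:4
elpd_diff <- c(-331, -60, -0.8, -0.3, 0)
se_diff   <- c(43, 12, 1.0, 0.5, 0)

# Smallest size whose one-SE upper bound reaches the reference performance
suggest_size_sketch <- function(sizes, elpd_diff, se_diff) {
  ok <- elpd_diff + se_diff >= 0
  min(sizes[ok])
}

print(suggest_size_sketch(sizes, elpd_diff, se_diff))
```

Using the upper bound makes the rule favour smaller models: a submodel is accepted as soon as it is not clearly worse than the reference model.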
Next we form the projected posterior for the chosen model.
projg <- project(cvvs, nv = nsel, ns = 4000)
projdraws <- as.matrix(projg)
round(colMeans(projdraws),1)
(Intercept)    x2    x1
       -6.2  15.3   6.8
round(posterior_interval(projdraws),1)
               5%   95%
(Intercept)  -7.1  -5.3
x2           14.1  16.5
x1            5.9   7.9
This looks good: the true values are intercept = -5.8, x2 = 15.2 and x1 = 6.3, and each lies inside the corresponding 90% interval above.
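With its default prob = 0.9, posterior_interval() reports central intervals, hence the 5% and 95% columns. I believe this amounts to taking column-wise quantiles of the draws; the base-R sketch below does that with a hypothetical helper, central_interval(), checks it on a matrix with known quantiles, and verifies that the reported intervals cover the true coefficients.

```r
# Hypothetical helper: central posterior intervals from a draws matrix
# (rows = draws, columns = parameters)
central_interval <- function(draws, prob = 0.9) {
  a <- (1 - prob) / 2
  t(apply(draws, 2, quantile, probs = c(a, 1 - a)))
}

# Check on columns with known quantiles: 0, 1, ..., 1000 and its tenth
toy_draws <- cbind(a = 0:1000, b = 0:1000 / 10)
toy_ci <- central_interval(toy_draws)
print(toy_ci)

# Do the reported 90% intervals cover the true values?
covers <- function(interval, truth) interval[, 1] <= truth & truth <= interval[, 2]
reported <- rbind("(Intercept)" = c(-7.1, -5.3), x2 = c(14.1, 16.5), x1 = c(5.9, 7.9))
truth    <- c("(Intercept)" = -5.8, x2 = 15.2, x1 = 6.3)
print(covers(reported, truth))
```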
mcmc_areas(projdraws, pars=c("(Intercept)", vsel))
Even though we started with a model whose posterior was difficult to interpret due to collinearity, the projected posterior closely matches the true values. The necessary information was in the full model, and with the projection we were able to form the projected posterior, which is the one we should use if x3 and x4 are set to 0.
Back to Tyre’s question, “Does model averaging make sense?”. If we are interested just in good predictions, we can do continuous model averaging by using suitable priors and by integrating over the posterior. When making predictions, we don’t first average the weights (i.e. compute the posterior mean), but use all the weight values to compute predictions and then average the predictions. All this is automatic in the Bayesian framework.
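The distinction matters for a log-link model like the Poisson regression here. Because exp() is convex, averaging the predicted expected counts over posterior draws gives a larger value than predicting at the posterior-mean weights (Jensen’s inequality). The draws in the base-R sketch below are simulated, hypothetical values chosen only for illustration; with the fitted model, draws of the expected counts could be obtained with posterior_epred(fitg).

```r
set.seed(4)
S <- 4000
# Hypothetical posterior draws for an intercept and one slope (not from fitg)
b0 <- rnorm(S, mean = -5.8, sd = 0.5)
b1 <- rnorm(S, mean = 15.2, sd = 1.5)
x_new <- 0.6

eta_draws    <- b0 + b1 * x_new
pred_bayes   <- mean(exp(eta_draws))               # average of the predictions
pred_plug_in <- exp(mean(b0) + mean(b1) * x_new)   # prediction at averaged weights

print(round(c(bayes = pred_bayes, plug_in = pred_plug_in), 1))
```

The plug-in prediction ignores the posterior uncertainty in the weights and, for this model, systematically underestimates the expected count.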
Tyre also commented on the problems of measuring variable importance. The projection predictive approach above is derived using decision theory and is very helpful for measuring relevance and choosing relevant variables. Tyre did not comment on inference after selection, although it is also a known problem in variable selection. The projection predictive approach above solves that problem, too.
Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A. and Rubin, D. B. (2013) Bayesian data analysis, third edition. CRC Press.
McLatchie, Y., Rögnvaldsson, S., Weber, F. and Vehtari, A. (2023) ‘Robust and efficient projection predictive inference’, arXiv preprint arXiv:2306.15581.
Piironen, J., Paasiniemi, M. and Vehtari, A. (2020) ‘Projective inference in high-dimensional problems: Prediction and feature selection’, Electronic Journal of Statistics, 14(1), pp. 2155–2197.
Piironen, J. and Vehtari, A. (2017) ‘Comparison of Bayesian predictive methods for model selection’, Statistics and Computing, 27(3), pp. 711–735. doi: 10.1007/s11222-016-9649-y.
Vehtari, A., Gelman, A. and Gabry, J. (2017) ‘Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC’, Statistics and Computing, 27(5), pp. 1413–1432. doi: 10.1007/s11222-016-9696-4.
sessionInfo()
R version 4.2.2 Patched (2022-11-10 r83330)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=fi_FI.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=fi_FI.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=fi_FI.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=fi_FI.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] projpred_2.7.0 bayesplot_1.10.0 GGally_2.1.2
[4] bridgesampling_1.1-2 ggridges_0.5.4 ggplot2_3.4.4
[7] loo_2.6.0 rstanarm_2.26.1 Rcpp_1.0.11
[10] tidyr_1.3.0
loaded via a namespace (and not attached):
[1] minqa_1.2.5 colorspace_2.1-0 ellipsis_0.3.2
[4] markdown_1.8 QuickJSR_1.0.5 base64enc_0.1-3
[7] farver_2.1.1 rstan_2.26.23 DT_0.29
[10] fansi_1.0.5 mvtnorm_1.2-3 codetools_0.2-19
[13] splines_4.2.2 cachem_1.0.8 knitr_1.43
[16] shinythemes_1.2.0 jsonlite_1.8.7 nloptr_2.0.3
[19] shiny_1.7.5 compiler_4.2.2 backports_1.4.1
[22] Matrix_1.5-1 fastmap_1.1.1 cli_3.6.1
[25] later_1.3.1 htmltools_0.5.6 prettyunits_1.1.1
[28] tools_4.2.2 igraph_1.5.1 coda_0.19-4
[31] gtable_0.3.4 glue_1.6.2 reshape2_1.4.4
[34] dplyr_1.1.3 posterior_1.4.1 jquerylib_0.1.4
[37] vctrs_0.6.4 nlme_3.1-162 crosstalk_1.2.0
[40] tensorA_0.36.2 xfun_0.40 stringr_1.5.0
[43] ps_1.7.5 lme4_1.1-34 mime_0.12
[46] miniUI_0.1.1.1 lifecycle_1.0.3 gtools_3.9.4
[49] MASS_7.3-58.2 zoo_1.8-12 scales_1.2.1
[52] colourpicker_1.3.0 hms_1.1.3 promises_1.2.1
[55] Brobdingnag_1.2-9 parallel_4.2.2 inline_0.3.19
[58] shinystan_2.6.0 RColorBrewer_1.1-3 yaml_2.3.7
[61] gridExtra_2.3 StanHeaders_2.26.28 sass_0.4.7
[64] reshape_0.8.9 stringi_1.7.12 highr_0.10
[67] dygraphs_1.1.1.6 checkmate_2.3.0 boot_1.3-28
[70] pkgbuild_1.4.2 rlang_1.1.1 pkgconfig_2.0.3
[73] matrixStats_1.0.0 distributional_0.3.2 evaluate_0.21
[76] lattice_0.20-45 purrr_1.0.2 rstantools_2.3.1.1
[79] htmlwidgets_1.6.2 labeling_0.4.3 processx_3.8.2
[82] tidyselect_1.2.0 plyr_1.8.8 magrittr_2.0.3
[85] R6_2.5.1 generics_0.1.3 pillar_1.9.0
[88] withr_2.5.1 xts_0.13.1 survival_3.4-0
[91] abind_1.4-5 tibble_3.2.1 crayon_1.5.2
[94] utf8_1.2.4 rmarkdown_2.24 progress_1.2.2
[97] grid_4.2.2 callr_3.7.3 threejs_0.3.3
[100] digest_0.6.33 xtable_1.8-4 httpuv_1.6.11
[103] RcppParallel_5.1.7 stats4_4.2.2 munsell_0.5.0
[106] bslib_0.5.1 shinyjs_2.1.0