Student grades and variable selection with projpred

Author: Aki Vehtari
Published: 2023-12-14
Modified: 2024-04-18

1 Portugal students data

We work with an example of predicting mathematics and Portuguese exam grades for a sample of high school students in Portugal. The same data was used in Chapter 12 of the book Regression and Other Stories to illustrate different models for regression coefficients.

We predict the students’ final-year median exam grade in mathematics (n=382) and Portuguese (n=657) given a large number of potentially relevant predictors: student’s school, student’s sex, student’s age, student’s home address type, family size, parents’ cohabitation status, mother’s education, father’s education, home-to-school travel time, weekly study time, number of past class failures, extra educational support, extra paid classes within the course subject, extra-curricular activities, whether the student attended nursery school, whether the student wants to take higher education, internet access at home, whether the student has a romantic relationship, quality of family relationships, free time after school, going out with friends, weekday alcohol consumption, weekend alcohol consumption, current health status, and number of school absences.

2 Variable selection

If we cared only about predictive performance, we would not need variable selection; we would simply use all the variables with a sensible joint prior. Here we are interested in finding the smallest set of variables that provides predictive performance similar to using all the variables (with a sensible prior). This helps to improve explainability and to design further studies that could also include interventions. We are not considering causal structure, and the selected variables are unlikely to have a direct causal effect, but selected variables with high predictive relevance are ones whose role in the causal graph should eventually be considered.

We first build models with all predictors, and then use projection predictive variable selection (Piironen, Paasiniemi, and Vehtari 2020; McLatchie et al. 2023) implemented in the R package projpred.

Load packages

library("brms")
options(brms.backend="cmdstanr")
library(cmdstanr)
options(mc.cores = parallel::detectCores()-2)
library("posterior")
options(digits=2, posterior.digits=2,
        pillar.neg = FALSE, pillar.subtle=FALSE, pillar.sigfig=2)
library("loo")
library("projpred")
devtools::load_all('~/proj/projpred')
library("ggplot2")
library("bayesplot")
theme_set(bayesplot::theme_default(base_family = "sans", base_size=16))
set1 <- RColorBrewer::brewer.pal(7, "Set1")
library(tinytable)
options(tinytable_format_num_fmt = "significant_cell", tinytable_format_digits = 2, tinytable_tt_digits=2)
library("dplyr")
library("matrixStats")

Set random seed for reproducibility

SEED <- 2132
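Note that brms does not pick up this variable automatically; for the sampling to be reproducible, the seed would be passed explicitly to each fit, e.g. brm(..., seed = SEED).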

3 Data

Get the data from the Regression and Other Stories examples repository (ROS-Examples).

student <- read.csv(url('https://raw.githubusercontent.com/avehtari/ROS-Examples/master/Student/data/student-merged.csv'))

List the predictors to be used.

predictors <- c("school","sex","age","address","famsize","Pstatus","Medu","Fedu","traveltime","studytime","failures","schoolsup","famsup","paid","activities", "nursery", "higher", "internet", "romantic","famrel","freetime","goout","Dalc","Walc","health","absences")
p <- length(predictors)

Compute median mathematics and Portuguese grades based on three exams for each topic. Select only students with non-zero grades.

grades <- c("G1mat","G2mat","G3mat","G1por","G2por","G3por")
student <- student %>%
  mutate(across(matches("G[1-3]..."), ~na_if(.,0))) %>%
  mutate(Gmat = rowMedians(as.matrix(select(.,matches("G.mat"))), na.rm=TRUE),
         Gpor = rowMedians(as.matrix(select(.,matches("G.por"))), na.rm=TRUE))
student_Gmat <- subset(student, is.finite(Gmat), select=c("Gmat",predictors))
student_Gmat <- student_Gmat[is.finite(rowMeans(student_Gmat)),]
student_Gpor <- subset(student, is.finite(Gpor), select=c("Gpor",predictors))
(nmat <- nrow(student_Gmat))
[1] 382
head(student_Gmat) |> tt()
Gmat school sex age address famsize Pstatus Medu Fedu traveltime studytime failures schoolsup famsup paid activities nursery higher internet romantic famrel freetime goout Dalc Walc health absences
10 0 0 15 0 0 1 1 1 2 4 1 1 1 1 1 1 1 1 0 3 1 2 1 1 1 2
6 0 0 15 0 0 1 1 1 1 2 2 1 1 0 0 0 1 1 1 3 3 4 2 4 5 2
13 0 0 15 0 0 1 2 2 1 1 0 1 1 1 1 1 1 0 0 4 3 1 1 1 2 8
9 0 0 15 0 0 1 2 4 1 3 0 1 1 1 1 1 1 1 0 4 3 2 1 1 5 2
10 0 0 15 0 0 1 3 3 2 3 2 0 1 1 1 1 1 1 1 4 2 1 2 3 3 8
12 0 0 15 0 0 1 3 4 1 3 0 1 1 1 1 1 1 1 0 4 3 2 1 1 5 2
(npor <- nrow(student_Gpor))
[1] 657
head(student_Gpor) |> tt()
Gpor school sex age address famsize Pstatus Medu Fedu traveltime studytime failures schoolsup famsup paid activities nursery higher internet romantic famrel freetime goout Dalc Walc health absences
13 0 0 15 0 0 1 1 1 2 4 1 1 1 1 1 1 1 1 0 3 1 2 1 1 1 2
11 0 0 15 0 0 1 1 1 1 2 2 1 1 0 0 0 1 1 1 3 3 4 2 4 5 2
13 0 0 15 0 0 1 2 2 1 1 0 1 1 1 1 1 1 0 0 4 3 1 1 1 2 8
10 0 0 15 0 0 1 2 4 1 3 0 1 1 1 1 1 1 1 0 4 3 2 1 1 5 2
13 0 0 15 0 0 1 3 3 2 3 2 0 1 1 1 1 1 1 1 4 2 1 2 3 3 8
12 0 0 15 0 0 1 3 4 1 3 0 1 1 1 1 1 1 1 0 4 3 2 1 1 5 2

Standardize all non-binary predictors for easier comparison of relevances, as discussed in Regression and Other Stories Section 12.1. Binary predictors are kept on their original 0/1 scale.

studentstd_Gmat <- student_Gmat
# detect binary predictors (exactly two distinct values) and scale only the rest
Gmatbin <- apply(student_Gmat[,predictors], 2, function(x) {length(unique(x))==2})
studentstd_Gmat[,predictors[!Gmatbin]] <- scale(student_Gmat[,predictors[!Gmatbin]])
studentstd_Gpor <- student_Gpor
Gporbin <- apply(student_Gpor[,predictors], 2, function(x) {length(unique(x))==2})
studentstd_Gpor[,predictors[!Gporbin]] <- scale(student_Gpor[,predictors[!Gporbin]])

4 Prior

Before variable selection, we want to build a good model with all covariates. We first illustrate that common default priors may be bad when we have many predictors.

By default, brms uses flat (improper uniform) priors on regression coefficients.

fitmu <- brm(Gmat ~ ., data = studentstd_Gmat,
             normalize=FALSE)

If we compare posterior-\(R^2\) (bayes_R2()) and LOO-\(R^2\) (loo_R2()) (Gelman et al. 2019), we see that the posterior-\(R^2\) is much higher, which means that the residual variance is strongly underestimated and the model has overfitted the data.

bayes_R2(fitmu) |> as.data.frame() |> tt()
Estimate Est.Error Q2.5 Q97.5
0.32 0.031 0.26 0.38
loo_R2(fitmu) |> as.data.frame() |> tt()
Estimate Est.Error Q2.5 Q97.5
0.2 0.04 0.11 0.27
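For intuition, the following is a minimal sketch of what bayes_R2() computes for this Gaussian model (Gelman et al. 2019): the ratio of the variance of the modelled predictive means to that variance plus the residual variance, separately for each posterior draw.

mu <- posterior_epred(fitmu)        # draws x observations matrix of E[y|x]
var_mu <- matrixStats::rowVars(mu)  # variance of the fit, per draw
sigma2 <- as_draws_df(fitmu, variable="sigma")$sigma^2 # residual variance, per draw
r2 <- var_mu / (var_mu + sigma2)    # Bayesian R^2, per draw
quantile(r2, c(0.025, 0.5, 0.975))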

Flat priors are improper, and we can’t do prior predictive simulation with them. The Regression and Other Stories Chapter 12 example shows prior predictive simulations with 1) the usual wide independent normal prior, 2) a scaled normal prior, and 3) the regularized horseshoe prior, illustrating the implied prior on \(R^2\).

Here we use the R2D2 prior (Zhang et al. 2022), which allows defining the prior directly on \(R^2\). We assign a prior with mean 1/3 and precision 3, which corresponds to a Beta(1,2) distribution on \(R^2\).
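As a quick check of the mean–precision parameterization (shape1 = mean × precision, shape2 = (1 − mean) × precision):

shape1 <- (1/3) * 3     # = 1
shape2 <- (1 - 1/3) * 3 # = 2
curve(dbeta(x, shape1, shape2), from=0, to=1,
      xlab=expression(R^2), ylab="prior density")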

fitm <- brm(Gmat ~ ., data = studentstd_Gmat,
            prior=c(prior(R2D2(mean_R2 = 1/3, prec_R2 = 3, cons_D2 = .3,
                               autoscale = FALSE),class=b),
                    prior(normal(0,1), class=sigma)),
            normalize=FALSE)
fitm <- add_criterion(fitm, criterion='loo')

Posterior-\(R^2\) (bayes_R2()) and LOO-\(R^2\) (loo_R2()) are now more similar indicating that the prior is not pushing the posterior towards higher \(R^2\) values.

bayes_R2(fitm) |> as.data.frame() |> tt()
Estimate Est.Error Q2.5 Q97.5
0.24 0.036 0.17 0.31
loo_R2(fitm) |> as.data.frame() |> tt()
Estimate Est.Error Q2.5 Q97.5
0.21 0.031 0.15 0.27

We plot the marginal posteriors for coefficients.

drawsm <- as_draws_df(fitm, variable=paste0('b_',predictors)) |>
  set_variables(predictors)
p <- mcmc_areas(drawsm,
                 prob_outer=0.98, area_method = "scaled height") +
  xlim(c(-3,3))
p <- p + scale_y_discrete(limits = rev(levels(p$data$parameter)))
p

For many coefficients the posterior has been shrunk close to 0, while some marginal posteriors remain wide. We check the bivariate marginal for the Fedu and Medu coefficients, and see that while the univariate marginals overlap 0, jointly there is not much posterior mass near 0. This is because Fedu and Medu are collinear. Collinearity of predictors makes it difficult to infer predictor relevance from the marginal posteriors.

mcmc_scatter(drawsm, pars = c("Fedu","Medu")) +
  vline_0(linetype='dashed') +
  hline_0(linetype='dashed')
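The collinearity can also be checked directly from the data (a quick check; output not shown):

cor(student_Gmat$Fedu, student_Gmat$Medu) # sample correlation of the two predictors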

5 Projection predictive variable selection

We use projection predictive variable selection implemented in the projpred R package to find the minimal set of predictors that can provide predictive performance similar to using all predictors jointly. We start by doing fast PSIS-LOO-CV only for the full-data search path.

vselm_fast <- cv_varsel(fitm, nterms_max = 27, validate_search = FALSE)

The following plot shows the relevance order of the predictors and the estimated predictive performance given those variables. As the search can overfit and we didn’t cross-validate the search, the performance estimates can go above the reference model performance. However, this plot helps us to see that 10 or fewer predictors would be sufficient.

plot(vselm_fast, stats=c("elpd", "R2"), deltas = TRUE,
     text_angle = 45, alpha = 0.1, 
     size_position = "primary_x_top", show_cv_proportions=FALSE) +
  geom_vline(xintercept = seq(0, 25, by = 5), colour = "black", alpha = 0.1)
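The same performance estimates can also be inspected numerically with the summary() method (arguments mirroring the plot call above):

summary(vselm_fast, stats=c("elpd", "R2"), deltas=TRUE)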

Next we repeat the search, but now cross-validate the search, too. We repeat the search with the PSIS-LOO-CV criterion only for nloo=50 folds, and combine the result with the fast PSIS-LOO result using a difference estimator (Magnusson et al. 2020). Based on the previous quick result, we search only up to models of size 10. To make the computation faster we use refit_prj=FALSE. For improved estimation stability, this can be omitted so that the default refit_prj=TRUE is used.

library(doFuture)
library(tictoc)
registerDoFuture()
plan(multisession, workers = 8)
tic()
vselm <- cv_varsel(fitm, nterms_max = 10, validate_search = TRUE,
                   refit_prj = FALSE, nloo = 50,
                   parallel = TRUE)
toc()
plan(sequential)

The following plot shows the relevance order of the predictors and the estimated predictive performance given those variables. The order is the same as in the previous plot, but now the predictive performance estimates take the search into account and have smaller bias. It seems that just four predictors can provide predictive performance similar to using all the predictors.

plot(vselm, stats=c("elpd","R2"), deltas=TRUE,
     text_angle=45, alpha=0.1, 
     size_position = 'primary_x_top', show_cv_proportions=FALSE) +
  geom_vline(xintercept=seq(0,10,by=5), colour='black', alpha=0.1)

projpred can also provide a suggestion for a sufficient model size.

(nselm <- suggest_size(vselm))
[1] 4

Form the projected posterior for the selected model.

rankm <- ranking(vselm, nterms=nselm)
projm <- project(vselm, nterms=nselm)
drawsm_proj <- as_draws_df(projm) |>
  subset_draws(variable=paste0('b_',rankm$fulldata[1:nselm])) |>
  set_variables(variable=rankm$fulldata[1:nselm])

The marginals of the projected posterior are all clearly away from 0.

mcmc_areas(drawsm_proj, prob_outer=0.98, area_method = "scaled height")
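The projected submodel can also be used for prediction; for example, a minimal sketch using projpred's proj_linpred(), here simply evaluated at the observed data:

# linear predictor from the projected 4-predictor model,
# averaged over the projected posterior draws
predm <- proj_linpred(projm, newdata=studentstd_Gmat, integrated=TRUE)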

The following plot shows the stability of the search over the different LOO-CV folds. The numbers indicate the proportion of folds in which the given predictor was ranked at the given model size or earlier.

plot(cv_proportions(rankm, cumulate=TRUE))

6 Portuguese

We repeat the same analysis, but predict the grade for Portuguese instead of mathematics.

Fit a model with R2D2 prior with mean 1/3 and precision 3.

fitp <- brm(Gpor ~ ., data = studentstd_Gpor,
              prior=c(prior(R2D2(mean_R2 = 1/3, prec_R2 = 3, cons_D2 = .2,
                                autoscale = FALSE),class=b)))

Compare posterior-\(R^2\) and LOO-\(R^2\). We see that the Portuguese grade is easier to predict given the predictors (but there is still a lot of unexplained variance).

fitp <- add_criterion(fitp, criterion='loo')
bayes_R2(fitp) |> round(2)
   Estimate Est.Error Q2.5 Q97.5
R2     0.31      0.04 0.23  0.38
loo_R2(fitp) |> round(2)
   Estimate Est.Error Q2.5 Q97.5
R2     0.28      0.03  0.2  0.34

Plot the marginal posteriors of the coefficients.

drawsp <- as_draws_df(fitp, variable=paste0('b_',predictors)) |>
  set_variables(predictors)
p <- mcmc_areas(drawsp, prob_outer=0.98, area_method = "scaled height") +
  xlim(c(-3,3))
p <- p + scale_y_discrete(limits = rev(levels(p$data$parameter)))
p

We use projection predictive variable selection with fast PSIS-LOO-CV only for the full-data search path.

vselp_fast <- cv_varsel(fitp, nterms_max=27, validate_search=FALSE)

The following plot shows the relevance order of the predictors and the estimated predictive performance given those variables. As there is some overfitting in the search and we didn’t cross-validate the search, the performance estimates can go above the reference model performance. However, this plot helps us to see that 10 or fewer predictors would be sufficient.

plot(vselp_fast, stats = c("elpd","R2"), deltas=TRUE,
     text_angle=45, alpha=0.1, 
     size_position = 'primary_x_top', show_cv_proportions=FALSE) +
  geom_vline(xintercept=seq(0,25,by=5), colour='black', alpha=0.1)

Next we repeat the search, but now cross-validate the search, too. As before, we repeat the search with the PSIS-LOO-CV criterion only for nloo=50 folds and combine the result with the fast PSIS-LOO result using a difference estimator (Magnusson et al. 2020). Based on the previous quick result, we search only up to models of size 10. Here we use the default refit_prj=TRUE, which is slower but gives more stable estimates.

registerDoFuture()
plan(multisession, workers = 8)
tic()
vselp <- cv_varsel(fitp, nterms_max=10, validate_search=TRUE,
                   refit_prj=TRUE, nloo=50,
                   parallel=TRUE)
toc()
plan(sequential)

The following plot shows the relevance order of the predictors and the estimated predictive performance given those variables. The order is the same as in the previous plot, but now the predictive performance estimates take the search into account and have smaller bias. It seems that just seven predictors can provide predictive performance similar to using all the predictors.

plot(vselp, stats=c("elpd", "R2"), deltas=TRUE,
     text_angle=45, alpha=0.1,
     size_position = 'primary_x_top', show_cv_proportions=FALSE) +
   geom_vline(xintercept=seq(0,10,by=5), colour='black', alpha=0.1)

projpred can also provide a suggestion for a sufficient model size.

(nselp <- suggest_size(vselp))
[1] 7

Form the projected posterior for the selected model.

rankp <- ranking(vselp, nterms=nselp)
projp <- project(vselp, nterms=nselp)
drawsp_proj <- as_draws_df(projp) |>
  subset_draws(variable=paste0('b_',rankp$fulldata[1:nselp])) |>
  set_variables(variable=rankp$fulldata[1:nselp])

The marginals of the projected posterior are all clearly away from 0.

mcmc_areas(drawsp_proj, prob_outer=0.98, area_method = "scaled height")

The following plot shows the stability of the search over the different LOO-CV folds. The numbers indicate the proportion of folds in which the given predictor was ranked at the given model size or earlier.

plot(cv_proportions(rankp, cumulate=TRUE))
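As a final check (not part of the selection itself), the selected predictor sets for the two subjects can be compared directly:

rankm$fulldata[1:nselm]  # predictors selected for mathematics
rankp$fulldata[1:nselp]  # predictors selected for Portuguese
intersect(rankm$fulldata[1:nselm], rankp$fulldata[1:nselp]) # common to both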

References

Gelman, Andrew, Ben Goodrich, Jonah Gabry, and Aki Vehtari. 2019. “R-Squared for Bayesian Regression Models.” The American Statistician 73 (3): 307–9.
Magnusson, Måns, Michael Riis Andersen, Johan Jonasson, and Aki Vehtari. 2020. “Leave-One-Out Cross-Validation for Bayesian Model Comparison in Large Data.” In Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics (AISTATS), 108:341–51. PMLR.
McLatchie, Yann, Sölvi Rögnvaldsson, Frank Weber, and Aki Vehtari. 2023. “Robust and Efficient Projection Predictive Inference.” arXiv Preprint arXiv:2306.15581.
Piironen, Juho, Markus Paasiniemi, and Aki Vehtari. 2020. “Projective Inference in High-Dimensional Problems: Prediction and Feature Selection.” Electronic Journal of Statistics 14 (1): 2155–97.
Zhang, Yan Dora, Brian P. Naughton, Howard D. Bondell, and Brian J. Reich. 2022. “Bayesian Regression Using a Prior on the Model Fit: The R2-D2 Shrinkage Prior.” Journal of the American Statistical Association 117: 862–74.

Licenses

  • Code © 2023-2024, Aki Vehtari, licensed under BSD-3.
  • Text © 2023-2024, Aki Vehtari, licensed under CC-BY-NC 4.0.

Original Computing Environment

sessionInfo()
R version 4.3.3 (2024-02-29)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.4 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0

locale:
 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_DK.UTF-8        LC_COLLATE=en_GB.UTF-8    
 [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
 [7] LC_PAPER=fi_FI.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Helsinki
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] matrixStats_1.2.0    dplyr_1.1.4          tinytable_0.2.1.2   
 [4] bayesplot_1.11.1     ggplot2_3.5.0        projpred_2.8.0.9000 
 [7] testthat_3.2.1       loo_2.7.0.9000       posterior_1.5.0.9000
[10] cmdstanr_0.7.1       brms_2.20.13         Rcpp_1.0.12         

loaded via a namespace (and not attached):
  [1] gridExtra_2.3        remotes_2.5.0        inline_0.3.19       
  [4] rlang_1.1.3          magrittr_2.0.3       ggridges_0.5.6      
  [7] compiler_4.3.3       mgcv_1.9-1           vctrs_0.6.5         
 [10] reshape2_1.4.4       stringr_1.5.1        profvis_0.3.8       
 [13] pkgconfig_2.0.3      fastmap_1.1.1        backports_1.4.1     
 [16] ellipsis_0.3.2       labeling_0.4.3       utf8_1.2.4          
 [19] threejs_0.3.3        promises_1.2.1       rmarkdown_2.25      
 [22] markdown_1.12        sessioninfo_1.2.2    nloptr_2.0.3        
 [25] ps_1.7.6             purrr_1.0.2          xfun_0.43           
 [28] cachem_1.0.8         jsonlite_1.8.8       later_1.3.2         
 [31] parallel_4.3.3       R6_2.5.1             dygraphs_1.1.1.6    
 [34] RColorBrewer_1.1-3   stringi_1.8.3        StanHeaders_2.32.5  
 [37] boot_1.3-30          pkgload_1.3.4        numDeriv_2016.8-1.1 
 [40] brio_1.1.4           rstan_2.32.5         knitr_1.45          
 [43] zoo_1.8-12           usethis_2.2.2        base64enc_0.1-3     
 [46] nnet_7.3-19          splines_4.3.3        httpuv_1.6.14       
 [49] Matrix_1.6-5         igraph_2.0.2         tidyselect_1.2.0    
 [52] rstudioapi_0.15.0    abind_1.4-5          yaml_2.3.8          
 [55] memisc_0.99.31.7     codetools_0.2-19     miniUI_0.1.1.1      
 [58] curl_5.2.0           processx_3.8.4       pkgbuild_1.4.4      
 [61] lattice_0.22-5       tibble_3.2.1         plyr_1.8.9          
 [64] shiny_1.8.0          withr_3.0.0          bridgesampling_1.1-2
 [67] ordinal_2023.12-4    coda_0.19-4.1        evaluate_0.23       
 [70] desc_1.4.3           RcppParallel_5.1.7   urlchecker_1.0.1    
 [73] xts_0.13.2           pillar_1.9.0         tensorA_0.36.2.1    
 [76] checkmate_2.3.1      DT_0.32              stats4_4.3.3        
 [79] shinyjs_2.1.0        distributional_0.4.0 generics_0.1.3      
 [82] rprojroot_2.0.4      rstantools_2.4.0     munsell_0.5.0       
 [85] scales_1.3.0         minqa_1.2.6          gtools_3.9.5        
 [88] xtable_1.8-4         gamm4_0.2-6          glue_1.7.0          
 [91] mclogit_0.9.6        tools_4.3.3          shinystan_2.6.0     
 [94] data.table_1.15.0    lme4_1.1-35.1        colourpicker_1.3.0  
 [97] fs_1.6.3             mvtnorm_1.2-4        grid_4.3.3          
[100] QuickJSR_1.1.3       crosstalk_1.2.1      devtools_2.4.5      
[103] colorspace_2.1-0     nlme_3.1-163         cli_3.6.2           
[106] fansi_1.0.6          Brobdingnag_1.2-9    V8_4.4.2            
[109] gtable_0.3.4         digest_0.6.35        ucminf_1.2.1        
[112] farver_2.1.1         htmlwidgets_1.6.4    memoise_2.0.1       
[115] htmltools_0.5.7      lifecycle_1.0.4      mime_0.12           
[118] MASS_7.3-60          shinythemes_1.2.0