Bayesian variable selection for red wine quality ranking data

Author

Aki Vehtari

Published

2018-02-27

Modified

2026-03-28

1 Introduction

This notebook was inspired by Eric Novik’s slides “Deconstructing Stan Manual Part 1: Linear”. The idea is to demonstrate how easy it is to do good variable selection with brms, loo, and projpred.

In this notebook we illustrate Bayesian inference for model selection, including PSIS-LOO (Vehtari, Gelman, and Gabry 2017) and the projection predictive approach (McLatchie et al. 2023; Piironen, Paasiniemi, and Vehtari 2020; Piironen and Vehtari 2017), which makes decision-theoretically justified inference after model selection possible.

Load packages

library("rprojroot")
root <- has_file(".casestudies-root")$make_fix_file()
library(dplyr)
library(brms)
options(brms.backend = "cmdstanr")
options(mc.cores = 4)
library(loo)
library(ggplot2)
library(bayesplot)
theme_set(bayesplot::theme_default(base_family = "sans", base_size = 14))
library(ggdist)
library(posterior)
library(projpred)
SEED <- 170701694

2 Wine quality data

We use the Wine Quality data set from the UCI Machine Learning Repository. Duplicate rows are removed with distinct().

wine <- read.delim(root("winequality-red", "winequality-red.csv"), sep = ";") |>
  distinct()
(p <- ncol(wine))
[1] 12
prednames <- names(wine)[1:(p-1)]
glimpse(wine)
Rows: 1,359
Columns: 12
$ fixed.acidity        <dbl> 7.4, 7.8, 7.8, 11.2, 7.4, 7.9, 7.3, 7.8, 7.5, 6.7…
$ volatile.acidity     <dbl> 0.700, 0.880, 0.760, 0.280, 0.660, 0.600, 0.650, …
$ citric.acid          <dbl> 0.00, 0.00, 0.04, 0.56, 0.00, 0.06, 0.00, 0.02, 0…
$ residual.sugar       <dbl> 1.9, 2.6, 2.3, 1.9, 1.8, 1.6, 1.2, 2.0, 6.1, 1.8,…
$ chlorides            <dbl> 0.076, 0.098, 0.092, 0.075, 0.075, 0.069, 0.065, …
$ free.sulfur.dioxide  <dbl> 11, 25, 15, 17, 13, 15, 15, 9, 17, 15, 16, 9, 52,…
$ total.sulfur.dioxide <dbl> 34, 67, 54, 60, 40, 59, 21, 18, 102, 65, 59, 29, …
$ density              <dbl> 0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9964, 0…
$ pH                   <dbl> 3.51, 3.20, 3.26, 3.16, 3.51, 3.30, 3.39, 3.36, 3…
$ sulphates            <dbl> 0.56, 0.68, 0.65, 0.58, 0.56, 0.46, 0.47, 0.57, 0…
$ alcohol              <dbl> 9.4, 9.8, 9.8, 9.8, 9.4, 9.4, 10.0, 9.5, 10.5, 9.…
$ quality              <int> 5, 5, 5, 6, 5, 5, 7, 7, 5, 5, 5, 5, 5, 5, 7, 5, 4…

We scale the covariates so that when looking at the marginal posteriors for the effects they are on the same scale.

wine_scaled <- as.data.frame(scale(wine))
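As a small illustration of what scale() does, the following sketch uses a made-up toy data frame, not the wine data. Note that scale(wine) standardizes every column, including quality. The sketch checks that each standardized column has mean 0 and standard deviation 1, matching the manual formula (x - mean(x)) / sd(x).

```r
# Toy data frame standing in for the wine covariates (made-up values)
toy <- data.frame(a = c(7.4, 7.8, 11.2, 6.7),
                  b = c(0.70, 0.88, 0.28, 0.58))

# scale() centers each column at 0 and divides it by its standard deviation
toy_scaled <- as.data.frame(scale(toy))

# The same standardization written out by hand
manual <- as.data.frame(lapply(toy, function(x) (x - mean(x)) / sd(x)))

col_means <- colMeans(toy_scaled)
col_sds <- sapply(toy_scaled, sd)
col_means
col_sds
```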

3 Fit regression model

We use the brms package with the R2D2 prior (Zhang et al. 2022), which provides adaptive shrinkage of the regression coefficients.

fitg <- brm(quality ~ .,
            data = wine_scaled,
            prior = prior(R2D2(mean_R2 = 1/3, prec_R2 = 3)),
            seed = SEED,
            silent = 2,
            refresh = 0)
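The R2D2 prior places a Beta distribution on the coefficient of determination R². The following sketch shows what mean_R2 = 1/3 and prec_R2 = 3 imply. It assumes the usual mean-precision parameterization Beta(mean_R2 * prec_R2, (1 - mean_R2) * prec_R2); check the brms R2D2 documentation for the exact definition. Under that assumption:

- the shapes are 1 and 2,
- the prior mean of R² is 1/3,
- the prior puts probability 0.75 on R² < 0.5.

```r
# Hyperparameters used in the brm() call above
mean_R2 <- 1/3
prec_R2 <- 3

# Assumed mean/precision parameterization: R2 ~ Beta(mean * prec, (1 - mean) * prec)
shape1 <- mean_R2 * prec_R2        # 1
shape2 <- (1 - mean_R2) * prec_R2  # 2

# Prior mean of R2 and prior probability that R2 < 0.5
prior_mean <- shape1 / (shape1 + shape2)
p_below_half <- pbeta(0.5, shape1, shape2)

# Monte Carlo check of the prior mean
set.seed(1)
r2_draws <- rbeta(1e5, shape1, shape2)
c(prior_mean = prior_mean, p_below_half = p_below_half, mc_mean = mean(r2_draws))
```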

Let’s look at the summary:

summary(fitg)
 Family: gaussian 
  Links: mu = identity 
Formula: quality ~ fixed.acidity + volatile.acidity + citric.acid + residual.sugar + chlorides + free.sulfur.dioxide + total.sulfur.dioxide + density + pH + sulphates + alcohol 
   Data: wine_scaled (Number of observations: 1359) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Regression Coefficients:
                     Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
Intercept               -0.00      0.02    -0.04     0.04 1.00     6425
fixed.acidity            0.02      0.04    -0.05     0.11 1.00     3115
volatile.acidity        -0.25      0.03    -0.30    -0.19 1.00     4212
citric.acid             -0.02      0.03    -0.09     0.03 1.00     4270
residual.sugar           0.01      0.02    -0.04     0.05 1.00     3724
chlorides               -0.11      0.03    -0.16    -0.06 1.00     3895
free.sulfur.dioxide      0.03      0.03    -0.02     0.08 1.00     3039
total.sulfur.dioxide    -0.09      0.03    -0.16    -0.04 1.00     3256
density                 -0.01      0.03    -0.09     0.05 1.00     3516
pH                      -0.07      0.03    -0.14    -0.00 1.00     3224
sulphates                0.18      0.03     0.13     0.23 1.00     3432
alcohol                  0.38      0.03     0.32     0.44 1.00     3694
                     Tail_ESS
Intercept                3109
fixed.acidity            3509
volatile.acidity         3637
citric.acid              4135
residual.sugar           4143
chlorides                3204
free.sulfur.dioxide      3505
total.sulfur.dioxide     3265
density                  3330
pH                       1967
sulphates                3490
alcohol                  3654

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma     0.80      0.02     0.77     0.83 1.00     6558     2494

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Next we do a posterior predictive check.

pp_check(fitg)

Looking at this, we remember that the data are discrete quality rankings, so it would be better to use an ordinal model.

fito <- brm(ordered(quality) ~ .,
            family = cumulative("logit"),
            data = wine_scaled,
            prior = prior(R2D2(mean_R2 = 1/3, prec_R2 = 3)),
            seed = SEED,
            silent = 2,
            refresh = 0)

Let’s look at the summary:

summary(fito)

Although in general we can’t directly compare continuous and discrete data models, we can compare them directly when the target consists of integers, as explained in the Nabiximols case study. We use fast Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO) (Vehtari, Gelman, and Gabry 2017).

loo_compare(loo(fitg), loo(fito))
     elpd_diff se_diff
fito    0.0       0.0 
fitg -286.7       7.2 

The ordinal model has much better predictive performance. The elpd difference of about 287 is roughly 40 times its standard error.
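The integer argument can be made concrete with a small sketch. It uses a standard normal as an illustration, not the wine model. When the outcome takes integer values with unit spacing, a continuous density evaluated at an integer k approximates the probability of the interval [k - 0.5, k + 0.5]. That is why it can be compared with the probability mass of a discrete model.

```r
# Standard normal as a simple continuous model on the integer scale
mu <- 0
sigma <- 1
k <- -10:10

# Density evaluated at each integer vs. probability of [k - 0.5, k + 0.5]
density_at_k <- dnorm(k, mu, sigma)
interval_prob <- pnorm(k + 0.5, mu, sigma) - pnorm(k - 0.5, mu, sigma)
max_abs_diff <- max(abs(density_at_k - interval_prob))

# With unit spacing, the densities at the integers sum to (almost) one, like a pmf
total <- sum(density_at_k)
c(max_abs_diff = max_abs_diff, total = total)
```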

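The ordinal model is flexible enough that posterior predictive checks are unlikely to reveal any issues, as explained in Recommendations for visual predictive checks in Bayesian workflow (Säilynoja et al. 2025). With one cutpoint per category boundary, the model can closely reproduce the observed category frequencies. For example, the following bar plot is therefore usually not informative.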

pp_check(fito, type="bars")
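To see why the bar plot says little, consider the cumulative logit model, where P(y <= k) = plogis(c_k - eta) with cutpoints c_k and linear predictor eta. The sketch below uses hypothetical category counts and no predictors (eta = 0). Setting the cutpoints to the logits of the empirical cumulative proportions reproduces the observed category frequencies exactly, so the marginal bar plot cannot show a misfit.

```r
# Hypothetical counts for six ordered quality categories (not the wine data)
counts <- c(10, 50, 500, 550, 180, 20)
props <- counts / sum(counts)

# Cumulative logit: P(y <= k) = plogis(c_k - eta). With eta = 0 (no predictors),
# cutpoints c_k = qlogis(P(y <= k)) match the empirical cumulative proportions
cutpoints <- qlogis(cumsum(props)[-length(props)])
eta <- 0
cum_probs <- c(plogis(cutpoints - eta), 1)

# Category probabilities are differences of consecutive cumulative probabilities
cat_probs <- diff(c(0, cum_probs))
rbind(observed = props, model = cat_probs)
```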

We can now examine the marginal posteriors of the regression coefficients.

drawso <- as_draws_df(fito) |>
  subset_draws(variable = paste0('b_', prednames)) |>
  set_variables(variable = prednames)
mcmc_areas(drawso, prob_outer = .95)

Several 95% posterior intervals do not overlap 0, so there may be something useful here.

4 Projection predictive variable selection

We perform projection predictive variable selection (Piironen, Paasiniemi, and Vehtari 2020; Piironen and Vehtari 2017) using the projpred package. Fast PSIS-LOO (Vehtari, Gelman, and Gabry 2017) is used to choose the model size. The number of observations is large compared to the number of covariates. We can therefore assume that overfitting in the search is negligible and estimate the performance using LOO-CV only along the search path (validate_search = FALSE); see more about this in McLatchie et al. (2023).
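For a Gaussian model with identity link, projecting a reference model onto a submodel has a simple form. For each draw, minimizing the KL divergence from the reference predictive distribution over the submodel coefficients is a least-squares fit of the submodel to the reference model's fitted values, rather than to the observed data.

The sketch below shows this for a single "draw" with simulated data, using lm() fits as stand-ins for the reference model. It is a simplified illustration, not how projpred implements the projection. Ordinal models need the latent or augmented-data projections discussed next.

```r
set.seed(2)
n <- 200
X <- matrix(rnorm(n * 3), n, 3, dimnames = list(NULL, c("x1", "x2", "x3")))
y <- 1 + 0.8 * X[, "x1"] + 0.3 * X[, "x2"] + rnorm(n)

# Stand-in for one draw of the reference model's linear predictor
ref_fit <- lm(y ~ X)
mu_ref <- fitted(ref_fit)

# Project onto the submodel with only x1: fit the submodel to mu_ref, not to y
x1 <- X[, "x1"]
sub_fit <- lm(mu_ref ~ x1)
mu_sub <- fitted(sub_fit)

# Projecting onto the full model recovers the reference fit exactly
full_proj <- lm(mu_ref ~ X)

# How far the projected submodel predictions are from the reference predictions
rmse_ref_sub <- sqrt(mean((mu_ref - mu_sub)^2))
rmse_ref_sub
```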

For ordinal models we can use either the latent projection approach (Catalina, Bürkner, and Vehtari 2021) or the augmented-data projection (Weber, Glass, and Vehtari 2025). The augmented-data projection is more accurate, but much slower than the latent projection. It is a good idea to run the latent projection first. The augmented-data projection can then be run with a smaller nterms_max chosen based on the latent projection result.

We first use the latent projection.

fito_cv_latent <- cv_varsel(fito,
                            latent = TRUE,
                            method = "forward",
                            cv_method = "loo",
                            validate_search = FALSE)

We look at the estimated predictive performance of smaller models compared to the full model.

plot(fito_cv_latent, stats = c("elpd"), delta = TRUE)
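The method = "forward" search can be sketched in the same simplified Gaussian setting. Starting from the empty model, each step adds the predictor whose inclusion brings the projection closest to the reference fit. Here closeness is measured by the residual sum of squares. The simulated data and the lm() stand-in for the reference model are assumptions for illustration only.

```r
set.seed(3)
n <- 300
X <- matrix(rnorm(n * 4), n, 4, dimnames = list(NULL, paste0("x", 1:4)))
y <- 1.5 * X[, "x3"] + 0.7 * X[, "x1"] + rnorm(n)

# Stand-in for the reference model fit
mu_ref <- fitted(lm(y ~ X))

# Residual sum of squares when projecting mu_ref onto the given predictors
proj_rss <- function(vars) {
  Xs <- X[, vars, drop = FALSE]
  sum(residuals(lm(mu_ref ~ Xs))^2)
}

# Greedy forward search: add the predictor that brings the projection closest
# to the reference fit, until all predictors are in the ordering
selected <- character(0)
remaining <- colnames(X)
while (length(remaining) > 0) {
  rss <- sapply(remaining, function(v) proj_rss(c(selected, v)))
  best <- names(which.min(rss))
  selected <- c(selected, best)
  remaining <- setdiff(remaining, best)
}
selected
```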

We then repeat the selection using the augmented-data projection and nterms_max = 5.

fito_cv <- cv_varsel(fito,
                     nterms_max = 5,
                     method = "forward",
                     cv_method = "loo",
                     validate_search = FALSE)

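In this case, the predictor ordering is the same as with the latent projection. We again look at the estimated predictive performance of the submodels compared to the full model.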

plot(fito_cv, stats = c("elpd"), delta = TRUE)

Three or four variables seem to be needed to get the same performance as the full model. The estimated predictive performance does not go much above the reference model performance, so we know that using validate_search = FALSE was safe (see more in McLatchie et al. (2023)).

We can get a LOO-CV-based recommendation for the model size to choose.

(nsel <- suggest_size(fito_cv, alpha = 0.1))
[1] 4
(vsel <- ranking(fito_cv, nterms_max = nsel)$fulldata)
[1] "alcohol"              "volatile.acidity"     "sulphates"           
[4] "total.sulfur.dioxide"

projpred recommends using four variables: alcohol, volatile.acidity, sulphates, and total.sulfur.dioxide.
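The size suggestion is based on the uncertainty in the elpd difference between each submodel and the reference model. The following is a simplified sketch of this kind of rule, using hypothetical differences and standard errors rather than the wine results. It picks the smallest size whose upper interval bound reaches the reference performance. The interval construction (a normal approximation with a central 1 - alpha interval) is an assumption; see the projpred documentation of suggest_size() for the exact rule.

```r
# Hypothetical elpd differences (submodel minus reference) and their SEs
# for model sizes 1..5; these are not the wine results
elpd_diff <- c(-150, -40, -8, -1.5, 0.2)
se_diff <- c(15, 9, 4, 2, 0.5)

alpha <- 0.1
z <- qnorm(1 - alpha / 2)  # normal quantile for a central 1 - alpha interval

# Smallest size whose upper interval bound reaches the reference performance
upper <- elpd_diff + z * se_diff
suggested <- which(upper >= 0)[1]
suggested
```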

4.1 Projected posterior

Next we form the projected posterior for the chosen model. This projected model can be used in the future to make predictions by using only the selected variables.

projo <- project(fito_cv, nterms = nsel, ndraws = 400)

The marginals of projected posteriors look like this.

projdraws <- as_draws_df(projo) |>
  subset_draws(variable = paste0("b_", vsel)) |>
  set_variables(variable = vsel)
mcmc_areas(projdraws,
           prob_outer = 0.99,
           area_method = "scaled height")

4.2 Predicted qualities

We can examine how well we can actually predict the wine quality based on these predictors.

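# Draws from the projected predictive distribution; shift by 2 so that the
# draws match the observed quality levels 3-8
preds <- proj_predict(projo) + 2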

We can compare the predictive means and observed qualities.

pred_mean <- colMeans(preds)
ggplot(data.frame(observed = factor(wine$quality),
                  predicted_mean = pred_mean),
       aes(x = observed, y = predicted_mean)) +
  geom_swarm(color = "steelblue") +
  annotate("segment", x = 1, xend = 6, y = 3, yend = 8,
           linetype = "dashed", color = "gray50") +
  labs(x = "Observed quality", y = "Posterior predictive mean") +
  scale_y_continuous(breaks = 3:8, limits = c(3, 8))

Next we compute the average predicted probability for each quality category, grouped by the observed quality. This shows the full predictive distribution rather than just the mean.

qlevels <- sort(unique(wine$quality))
prob_df <- do.call(rbind, lapply(seq_along(wine$quality), function(i) {
  probs <- table(factor(preds[, i], levels = qlevels)) / nrow(preds)
  data.frame(observed = wine$quality[i],
             predicted = as.integer(names(probs)),
             prob = as.numeric(probs))
}))
prob_avg <- aggregate(prob ~ observed + predicted,
                      data = prob_df,
                      FUN = mean)
ggplot(prob_avg, aes(x = factor(observed), y = factor(predicted),
                     fill = prob)) +
  geom_tile() +
  geom_text(aes(label = round(prob, 2)), size = 4) +
  scale_fill_gradient(low = "white", high = "#2166AC",
                      name = "Probability") +
  labs(x = "Observed quality", y = "Predicted quality") +
  coord_equal() +
  annotate("segment", x = .5, xend = 6.5, y = .5, yend = 6.5,
           linetype = "dashed", alpha = 0.3)
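The per-observation probabilities computed in the loop above can also be obtained in vectorized form. For a draws × observations matrix of predicted qualities, colMeans(preds == q) gives the share of draws equal to quality q for every observation at once. The sketch uses a small random toy matrix in place of preds and checks the result against the table()-based computation for one observation.

```r
set.seed(4)
qlevels <- 3:8
# Toy stand-in for `preds`: 40 draws x 5 observations of integer qualities
toy_preds <- matrix(sample(qlevels, 40 * 5, replace = TRUE), nrow = 40, ncol = 5)

# Observations x categories matrix: share of draws equal to each quality level
toy_probs <- sapply(qlevels, function(q) colMeans(toy_preds == q))
colnames(toy_probs) <- qlevels

# The table()-based computation from the loop above, for the first observation
first_obs <- as.numeric(table(factor(toy_preds[, 1], levels = qlevels)) /
                          nrow(toy_preds))
toy_probs
```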

We can predict something, but there is plenty of unexplained variation, which makes sense considering the available predictors. The model distinguishes well between low (5) and high (7–8) quality wines, but there is substantial overlap in the middle categories.

4.3 Predicted probabilities

Instead of discrete posterior predictive draws, we can use proj_linpred with transform = TRUE to obtain the predictive probabilities of each quality category. This provides smoother and more informative summaries.

ppreds <- proj_linpred(projo, transform = TRUE)$pred

ppreds is a 3D array with dimensions (draws × observations × categories). We average over the projected draws to get the mean predicted probability for each observation and category.

qlevels <- 3:8
mean_probs <- apply(ppreds, c(2, 3), mean)
colnames(mean_probs) <- qlevels
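Here is a toy example of the array averaging, where a random array serves as a made-up stand-in for ppreds. After each draw's category probabilities are normalized to sum to one, averaging over the first dimension with apply(..., c(2, 3), mean) leaves an observations × categories matrix whose rows still sum to one. colMeans() on the array gives the same result.

```r
set.seed(5)
# Toy stand-in for `ppreds`: 4 draws x 3 observations x 6 categories
raw <- array(runif(4 * 3 * 6), dim = c(4, 3, 6))
# Normalize so the 6 category probabilities sum to one per draw and observation
toy_ppreds <- sweep(raw, c(1, 2), apply(raw, c(1, 2), sum), "/")

# Average over draws (dimension 1), keeping observations x categories
toy_mean_probs <- apply(toy_ppreds, c(2, 3), mean)

# colMeans() averages over the first dimension of an array as well
toy_mean_probs2 <- colMeans(toy_ppreds)
toy_mean_probs
```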

We compute the expected quality as the probability-weighted mean of the quality levels. Compared to the posterior predictive mean from discrete draws, this gives a smoother prediction.

exp_quality <- as.numeric(mean_probs %*% qlevels)
ggplot(data.frame(observed = factor(wine$quality),
                  expected = exp_quality),
       aes(x = observed, y = expected)) +
  geom_swarm(color = "steelblue") +
  annotate("segment", x = 1, xend = 6, y = 3, yend = 8,
           linetype = "dashed", color = "gray50") +
  labs(x = "Observed quality",
       y = "Expected quality (probability-weighted)") +
  scale_y_continuous(breaks = 3:8, limits = c(3, 8))
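The expected quality is E[quality] = sum over k of k * P(quality = k). For two hypothetical probability vectors over the levels 3 to 8, the same matrix product gives 5.5 and 6.0.

```r
qlevels <- 3:8
# Two hypothetical observations' category probabilities (each row sums to one)
probs <- rbind(c(0, 0, 0.5, 0.5, 0, 0),
               c(0, 0.1, 0.2, 0.4, 0.2, 0.1))
colnames(probs) <- qlevels

# E[quality] = sum_k k * P(quality = k), for all rows at once
expected <- as.numeric(probs %*% qlevels)
expected
```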

We can also compute the average predicted probability for each quality category grouped by observed quality. This is the probability-based analog of the confusion matrix.

prob_by_obs <- do.call(rbind, lapply(qlevels, function(q) {
  idx <- wine$quality == q
  if (sum(idx) == 0) return(NULL)
  data.frame(observed = q,
             predicted = qlevels,
             prob = colMeans(mean_probs[idx, ]))
}))
ggplot(prob_by_obs, aes(x = factor(observed), y = factor(predicted),
                        fill = prob)) +
  geom_tile() +
  geom_text(aes(label = sprintf("%.2f", prob)), size = 4) +
  scale_fill_gradient(low = "white", high = "#2166AC",
                      name = "Probability") +
  labs(x = "Observed quality", y = "Predicted quality") +
  coord_equal()

Finally, we visualize the full predicted probability distribution as stacked bars grouped by observed quality. This is an alternative way to show how the predicted probability is distributed across the quality categories for each observed quality level.

ggplot(prob_by_obs, aes(x = factor(observed), y = prob,
                        fill = factor(predicted))) +
  geom_col(position = position_stack(reverse = TRUE)) +
  scale_fill_brewer(palette = "RdYlBu", direction = -1,
                    name = "Predicted\nquality") +
  guides(fill = guide_legend(reverse = TRUE)) +
  labs(x = "Observed quality", y = "Average predicted probability") +
  scale_y_continuous(expand = expansion(mult = c(0, 0.02)))

The probability-based visualizations confirm that the model can separate low from high quality wines, but assigns substantial probability to neighboring categories, reflecting genuine uncertainty in the predictions.

References

Catalina, Alejandro, Paul Bürkner, and Aki Vehtari. 2021. “Latent Space Projection Predictive Inference.” arXiv Preprint arXiv:2109.04702.
McLatchie, Yann, Sölvi Rögnvaldsson, Frank Weber, and Aki Vehtari. 2023. “Robust and Efficient Projection Predictive Inference.” arXiv Preprint arXiv:2306.15581.
Piironen, Juho, Markus Paasiniemi, and Aki Vehtari. 2020. “Projective Inference in High-Dimensional Problems: Prediction and Feature Selection.” Electronic Journal of Statistics 14 (1): 2155–97.
Piironen, Juho, and Aki Vehtari. 2017. “Comparison of Bayesian Predictive Methods for Model Selection.” Statistics and Computing 27 (3): 711–35. https://doi.org/10.1007/s11222-016-9649-y.
Säilynoja, Teemu, Andrew R Johnson, Osvaldo A Martin, and Aki Vehtari. 2025. “Recommendations for Visual Predictive Checks in Bayesian Workflow.” arXiv Preprint arXiv:2503.01509.
Vehtari, Aki, Andrew Gelman, and Jonah Gabry. 2017. “Practical Bayesian Model Evaluation Using Leave-One-Out Cross-Validation and WAIC.” Statistics and Computing 27 (5): 1413–32. https://doi.org/10.1007/s11222-016-9696-4.
Weber, Frank, Änne Glass, and Aki Vehtari. 2025. “Projection Predictive Variable Selection for Discrete Response Families with Finite Support.” Computational Statistics 40 (2): 701–21.
Zhang, Yan Dora, Brian P. Naughton, Howard D. Bondell, and Brian J. Reich. 2022. “Bayesian Regression Using a Prior on the Model Fit: The R2-D2 Shrinkage Prior.” Journal of the American Statistical Association 117: 862–74.

Licenses

  • Code © 2017-2026, Aki Vehtari, licensed under BSD-3.
  • Text © 2017-2026, Aki Vehtari, licensed under CC-BY-NC 4.0.