Bayesian analysis with brms

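The marginaleffects package offers convenience functions to compute and display predictions, contrasts, and marginal effects from Bayesian models estimated by the brms package. To compute these quantities, marginaleffects relies on workhorse functions from the brms package to draw from the posterior distribution. The type of draws is controlled by the type argument of the predictions and marginaleffects functions: type = "response" draws expected values with brms::posterior_epred, type = "link" draws values of the linear predictor with brms::posterior_linpred, and type = "prediction" draws from the posterior predictive distribution with brms::posterior_predict.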

The predictions and marginaleffects functions can also pass additional arguments to the brms prediction functions via the ... ellipsis. For example, if mod is a mixed-effects model, then this command will compute 10 draws from the posterior predictive distribution, while ignoring all group-level effects:

predictions(mod, type = "prediction", ndraws = 10, re_formula = NA)

See the brms documentation for a list of available arguments:

?brms::posterior_epred
?brms::posterior_linpred
?brms::posterior_predict

Logistic regression with multiplicative interactions

Load libraries and download data on passengers of the Titanic from the Rdatasets archive:

library(marginaleffects)
library(brms)
library(ggplot2)
library(ggdist)
library(magrittr)

dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/carData/TitanicSurvival.csv")
dat$survived <- ifelse(dat$survived == "yes", 1, 0)
dat$woman <- ifelse(dat$sex == "female", 1, 0)

Fit a logit model with a multiplicative interaction:

mod <- brm(survived ~ woman * age + passengerClass,
           family = bernoulli(link = "logit"),
           data = dat)

Adjusted predictions

We can compute adjusted predicted values of the outcome variable (i.e., probability of survival aboard the Titanic) using the predictions function. By default, this function calculates predictions for each row of the dataset:
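A minimal sketch of that call, assuming `mod` is the Titanic logit model fitted above and the marginaleffects interface used throughout this vignette:

```r
library(marginaleffects)

# One row per observation used to fit the model: a posterior summary of
# the predicted probability of survival, with a credible interval
pred <- predictions(mod)
head(pred)
```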

To visualize the relationship between the outcome and one of the regressors, we can plot conditional adjusted predictions with the plot_cap function:
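For example, this sketch plots the predicted probability of survival against age, with the other regressors held at representative values. It assumes `mod` is the Titanic model above; the choice of `age` is illustrative:

```r
library(marginaleffects)

# x-axis: age; y-axis: predicted probability of survival with a credible band
p <- plot_cap(mod, condition = "age")
p
```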

Compute adjusted predictions for some user-specified values of the regressors, using the newdata argument and the datagrid function:
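A sketch, assuming `mod` is the Titanic model above. The grid values are chosen for illustration; datagrid() builds their cross product and holds unspecified regressors (here, age) at representative values:

```r
library(marginaleffects)

# 2 values of woman x 3 passenger classes = 6 rows; age held at its mean
pred <- predictions(mod,
                    newdata = datagrid(woman = 0:1,
                                       passengerClass = c("1st", "2nd", "3rd")))
pred
```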

The posteriordraws function extracts the posterior draws stored in the object returned by predictions, and produces a data frame with drawid and draw columns.

This “long” format makes it easy to plot results:
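Putting both steps together, here is a sketch that extracts the draws for a grid of women and men by passenger class (values chosen for illustration, `mod` assumed to be the Titanic model above) and plots one posterior density per group:

```r
library(marginaleffects)
library(ggplot2)

pred <- predictions(mod,
                    newdata = datagrid(woman = 0:1,
                                       passengerClass = c("1st", "2nd", "3rd")))

# Long format: one row per (grid row, posterior draw)
draws <- posteriordraws(pred)
head(draws)

ggplot(draws, aes(x = draw, fill = factor(woman))) +
  geom_density(alpha = 0.5) +
  facet_grid(~ passengerClass, labeller = label_both) +
  labs(x = "Predicted probability of survival", y = "Density", fill = "Woman")
```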

Marginal effects

Use marginaleffects() to compute marginal effects (slopes of the regression equation) for each row of the dataset, and use summary() to compute “Average Marginal Effects”, that is, the average of all observation-level marginal effects:
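A minimal sketch, assuming `mod` is the Titanic model above:

```r
library(marginaleffects)

# Observation-level marginal effects for every regressor
mfx <- marginaleffects(mod)
head(mfx)

# Average Marginal Effects: the average of the observation-level effects
summary(mfx)
```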

Compute marginal effects with some regressors fixed at user-specified values, and other regressors held at representative values (means for numeric variables, modes for categorical ones):
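A sketch, assuming `mod` is the Titanic model above; fixing `woman` at 0 and 1 is an illustrative choice:

```r
library(marginaleffects)

# woman fixed at 0 and 1; age and passengerClass held at representative values
mfx <- marginaleffects(mod, newdata = datagrid(woman = 0:1))
mfx
```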

Compute and plot conditional marginal effects:
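A sketch with plot_cme, assuming `mod` is the Titanic model above. Here the marginal effect of `woman` is plotted as a function of `age`, which is where the multiplicative interaction shows up; the variable pairing is illustrative:

```r
library(marginaleffects)

# y-axis: marginal effect of woman; x-axis: age
p <- plot_cme(mod, effect = "woman", condition = "age")
p
```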

The posteriordraws function produces a dataset with drawid and draw columns:

We can use this dataset to plot our results. For example, to plot the posterior density of the marginal effect of age when the woman variable is equal to 0 or 1:
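A sketch of both steps, assuming `mod` is the Titanic model above and using ggdist's stat_halfeye() to draw each posterior density with its interval:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

# Marginal effect of age, evaluated at woman = 0 and woman = 1
mfx <- marginaleffects(mod, variables = "age", newdata = datagrid(woman = 0:1))
draws <- posteriordraws(mfx)
head(draws)

ggplot(draws, aes(x = draw, fill = factor(woman))) +
  stat_halfeye(slab_alpha = 0.5) +
  labs(x = "Marginal effect of age on survival",
       y = "Posterior density",
       fill = "Woman")
```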

Random effects model

This section replicates some of the analyses of a random effects model published in Andrew Heiss’ blog post: “A guide to correctly calculating posterior predictions and average marginal effects with multilevel Bayesian models.” The objective is mainly to illustrate the use of marginaleffects. Please refer to the original post for a detailed discussion of the quantities computed below.

Load libraries and download data:

library(brms)
library(ggdist)
library(patchwork)
library(marginaleffects)

vdem_2015 <- read.csv("https://github.com/vincentarelbundock/marginaleffects/raw/main/data-raw/vdem_2015.csv")

head(vdem_2015)
#>   country_name country_text_id year                           region media_index party_autonomy_ord
#> 1       Mexico             MEX 2015  Latin America and the Caribbean       0.837                  3
#> 2     Suriname             SUR 2015  Latin America and the Caribbean       0.883                  4
#> 3       Sweden             SWE 2015 Western Europe and North America       0.956                  4
#> 4  Switzerland             CHE 2015 Western Europe and North America       0.939                  4
#> 5        Ghana             GHA 2015               Sub-Saharan Africa       0.858                  4
#> 6 South Africa             ZAF 2015               Sub-Saharan Africa       0.898                  4
#>   polyarchy civil_liberties party_autonomy
#> 1     0.631           0.704           TRUE
#> 2     0.777           0.887           TRUE
#> 3     0.915           0.968           TRUE
#> 4     0.901           0.960           TRUE
#> 5     0.724           0.921           TRUE
#> 6     0.752           0.869           TRUE

Fit a basic model:

mod <- brm(
  bf(media_index ~ party_autonomy + civil_liberties + (1 | region),
     phi ~ (1 | region)),
  data = vdem_2015,
  family = Beta(),
  control = list(adapt_delta = 0.9))

Posterior predictions

To compute posterior predictions for specific values of the regressors, we use the newdata argument and the datagrid function. We also use the type argument to compute two types of predictions: accounting for residual (observation-level) variance (prediction) or ignoring it (response).
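A sketch, assuming `mod` is the beta regression fitted above. It calls predictions() once per type; the grid values (both levels of `party_autonomy`, `civil_liberties = 0.5`, the Middle East and North Africa region) are illustrative:

```r
library(marginaleffects)

# Shared grid of regressor values
grid <- datagrid(model = mod,
                 party_autonomy = c(TRUE, FALSE),
                 civil_liberties = 0.5,
                 region = "Middle East and North Africa")

# Expected value of the outcome: ignores residual variance
p_response <- predictions(mod, newdata = grid, type = "response")

# Posterior predictive draws: includes residual variance
p_prediction <- predictions(mod, newdata = grid, type = "prediction")

p_response
p_prediction
```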

Extract posterior draws and plot them:
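A sketch, assuming `mod` is the beta regression above and the same illustrative grid. The two sets of draws are stacked in one long data frame with a hypothetical `pred_type` label column, so that the wider spread of posterior predictive draws is visible side by side:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

grid <- datagrid(model = mod,
                 party_autonomy = c(TRUE, FALSE),
                 civil_liberties = 0.5,
                 region = "Middle East and North Africa")

draws_response <- posteriordraws(predictions(mod, newdata = grid, type = "response"))
draws_prediction <- posteriordraws(predictions(mod, newdata = grid, type = "prediction"))

# Keep common columns, label each set, and stack them
keep <- c("drawid", "draw", "party_autonomy")
draws <- rbind(
  transform(draws_response[, keep], pred_type = "response"),
  transform(draws_prediction[, keep], pred_type = "prediction"))

ggplot(draws, aes(x = draw, fill = party_autonomy)) +
  stat_halfeye(slab_alpha = 0.5) +
  facet_wrap(~ pred_type) +
  labs(x = "Media index (predicted)", y = "Posterior density", fill = "Party autonomy")
```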

Marginal effects and contrasts

As noted in the Marginal Effects vignette, there should be one distinct marginal effect for each combination of regressor values. Here, we consider only one combination of regressor values, where region is “Middle East and North Africa”, and civil_liberties is 0.5. Then, we calculate the mean of the posterior distribution of marginal effects:
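A sketch, assuming `mod` is the beta regression above. The posterior mean is computed explicitly from the raw draws with base R's aggregate(), so it does not depend on which summary statistic the printed output uses:

```r
library(marginaleffects)

mfx <- marginaleffects(mod,
                       newdata = datagrid(civil_liberties = 0.5,
                                          region = "Middle East and North Africa"))
mfx

# Posterior mean of each marginal effect
aggregate(draw ~ term, data = posteriordraws(mfx), FUN = mean)
```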

Use posteriordraws() to extract draws from the posterior distribution of marginal effects, and plot them:
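A sketch, assuming `mod` is the beta regression above and the same single combination of regressor values; each term gets its own panel because the effects are on different scales:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

mfx <- marginaleffects(mod,
                       newdata = datagrid(civil_liberties = 0.5,
                                          region = "Middle East and North Africa"))
draws <- posteriordraws(mfx)

ggplot(draws, aes(x = draw)) +
  stat_halfeye() +
  facet_wrap(~ term, scales = "free") +
  labs(x = "Marginal effect on media index", y = "Posterior density")
```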

Plot marginal effects, conditional on a regressor:
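A sketch, assuming `mod` is the beta regression above; plotting the effect of `civil_liberties` conditional on `party_autonomy` is an illustrative pairing:

```r
library(marginaleffects)

p <- plot_cme(mod, effect = "civil_liberties", condition = "party_autonomy")
p
```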

Continuous predictors

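The slope of the predicted media index with respect to civil liberties (the marginal effect of civil_liberties), evaluated at different values of civil liberties, can be obtained with: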
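A sketch, assuming `mod` is the beta regression above; the three evaluation points and `party_autonomy = FALSE` are illustrative:

```r
library(marginaleffects)

# Slope with respect to civil_liberties at three values; region held at its mode
mfx <- marginaleffects(mod,
                       variables = "civil_liberties",
                       newdata = datagrid(civil_liberties = c(0.2, 0.5, 0.8),
                                          party_autonomy = FALSE))
mfx
```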

And plotted:
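A sketch that recomputes the slopes at the same illustrative values (`mod` assumed to be the beta regression above) and overlays their posterior densities:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

mfx <- marginaleffects(mod,
                       variables = "civil_liberties",
                       newdata = datagrid(civil_liberties = c(0.2, 0.5, 0.8),
                                          party_autonomy = FALSE))
draws <- posteriordraws(mfx)

ggplot(draws, aes(x = draw, fill = factor(civil_liberties))) +
  stat_halfeye(slab_alpha = 0.5) +
  labs(x = "Marginal effect of civil liberties on media index",
       y = "Posterior density",
       fill = "Civil liberties")
```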

The marginaleffects function can use the ellipsis (...) to push any argument forward to the underlying brms prediction function (posterior_epred, posterior_linpred, or posterior_predict, depending on type). This can alter the types of predictions returned. For example, the re_formula = NA argument of these brms functions computes marginal effects without including any group-level effects:
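A sketch, assuming `mod` is the beta regression above; `re_formula = NA` is passed through `...` to brms, and the evaluation points are illustrative:

```r
library(marginaleffects)

# Same slopes as before, but ignoring all region-level effects
mfx_no_re <- marginaleffects(mod,
                             variables = "civil_liberties",
                             newdata = datagrid(civil_liberties = c(0.2, 0.5, 0.8),
                                                party_autonomy = FALSE),
                             re_formula = NA)
mfx_no_re
```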

Global grand mean
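One way to compute grand-mean predictions, sketched under the assumption that "global grand mean" here means ignoring every region-level effect via `re_formula = NA` (`mod` is the beta regression above):

```r
library(marginaleffects)

# Population-level predictions for both levels of party_autonomy;
# civil_liberties held at its mean, region-level effects dropped
pred <- predictions(mod,
                    re_formula = NA,
                    newdata = datagrid(party_autonomy = c(TRUE, FALSE)))
pred
```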

Region-specific predictions and contrasts

Predicted media index by region and level of civil liberties:
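A sketch, assuming `mod` is the beta regression and `vdem_2015` the data loaded above. ggdist's stat_lineribbon() draws the posterior median with nested credible-interval ribbons, one panel per region:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

# Every observed region x 21 values of civil_liberties; party_autonomy = FALSE
pred <- predictions(mod,
                    newdata = datagrid(region = unique(vdem_2015$region),
                                       party_autonomy = FALSE,
                                       civil_liberties = seq(0, 1, by = 0.05)))
draws <- posteriordraws(pred)

ggplot(draws, aes(x = civil_liberties, y = draw)) +
  stat_lineribbon() +
  scale_fill_brewer(palette = "Reds") +
  facet_wrap(~ region) +
  labs(x = "Civil liberties", y = "Media index (predicted)", fill = "Credible interval")
```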


Predicted media index by region and party autonomy:
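A sketch, assuming `mod` and `vdem_2015` from above, with civil liberties fixed at an illustrative value of 0.5:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

pred <- predictions(mod,
                    newdata = datagrid(region = unique(vdem_2015$region),
                                       party_autonomy = c(TRUE, FALSE),
                                       civil_liberties = 0.5))
draws <- posteriordraws(pred)

# One row per region, two overlaid densities (TRUE/FALSE)
ggplot(draws, aes(x = draw, y = region, fill = party_autonomy)) +
  stat_halfeye(slab_alpha = 0.5) +
  labs(x = "Media index (predicted)", y = NULL, fill = "Party autonomy")
```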

TRUE/FALSE contrasts (marginal effects) of party autonomy by region:
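A sketch, assuming `mod` and `vdem_2015` from above. Because `party_autonomy` is logical, marginaleffects reports a TRUE minus FALSE contrast rather than a slope; `civil_liberties = 0.5` is illustrative:

```r
library(marginaleffects)
library(ggplot2)
library(ggdist)

mfx <- marginaleffects(mod,
                       variables = "party_autonomy",
                       newdata = datagrid(region = unique(vdem_2015$region),
                                          civil_liberties = 0.5))
draws <- posteriordraws(mfx)

ggplot(draws, aes(x = draw, y = region)) +
  stat_halfeye() +
  labs(x = "Contrast in media index (TRUE - FALSE)", y = NULL)
```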

Hypothetical groups

We can also obtain predictions or marginal effects for a hypothetical group instead of one of the observed regions. To achieve this, we create a dataset with NA in the region column. Then we call the marginaleffects or predictions functions with allow_new_levels = TRUE. This argument is pushed through the ellipsis (...) to the posterior_epred function of the brms package:
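A sketch following the approach described above, assuming `mod` is the beta regression fitted earlier. The grid is built with base R's expand.grid(); `hypothetical` is an illustrative name and the regressor values are arbitrary:

```r
library(marginaleffects)

# Grid with a missing region, i.e., a group not seen during estimation
hypothetical <- expand.grid(civil_liberties = c(0.2, 0.8),
                            party_autonomy = c(TRUE, FALSE),
                            region = NA_character_,
                            stringsAsFactors = FALSE)

# allow_new_levels is forwarded to brms::posterior_epred via ...
pred <- predictions(mod, newdata = hypothetical, allow_new_levels = TRUE)
pred
```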

Multinomial logit

Fit a model with a categorical outcome (heating system choice in California houses) and logit link:

dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/Heating.csv"
dat <- read.csv(dat)

mod <- brm(depvar ~ ic.gc + oc.gc,
           data = dat,
           family = categorical(link = "logit"))

Adjusted predictions

Compute predicted probabilities for each level of the outcome variable:
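A minimal sketch, assuming `mod` is the categorical model fitted above. For a categorical outcome, each observation yields one predicted probability per outcome level:

```r
library(marginaleffects)

pred <- predictions(mod)
head(pred)
```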

Extract posterior draws and plot them:
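A sketch, assuming `mod` is the categorical model above and that the outcome level is stored in a `group` column of the draws (an assumption about the output format of this marginaleffects version):

```r
library(marginaleffects)
library(ggplot2)

draws <- posteriordraws(predictions(mod))

# One density per outcome level (heating system)
ggplot(draws, aes(x = draw, fill = group)) +
  geom_density(alpha = 0.2, color = "white") +
  labs(x = "Predicted probability", y = "Density", fill = "Heating system")
```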

Use the plot_cap function to plot conditional adjusted predictions for each level of the outcome variable depvar, conditional on the value of the oc.gc regressor:
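A sketch, assuming `mod` is the categorical model above and conditioning on its oc.gc regressor:

```r
library(marginaleffects)

# Predicted probability of each heating system as oc.gc varies
p <- plot_cap(mod, condition = "oc.gc")
p
```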

Marginal effects
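A sketch mirroring the logit example earlier, assuming `mod` is the categorical model above; for this outcome, the output should contain one marginal effect per regressor and outcome level:

```r
library(marginaleffects)

mfx <- marginaleffects(mod)

# Average marginal effects of ic.gc and oc.gc on each outcome probability
summary(mfx)
```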