The marginaleffects package offers convenience functions to compute and display predictions, contrasts, and marginal effects from Bayesian models estimated by the brms package. To compute these quantities, marginaleffects relies on workhorse functions from the brms package to draw from the posterior distribution. The type of draws used is controlled by the type argument of the predictions or marginaleffects functions:
type = "response": Compute posterior draws of the expected value using the brms::posterior_epred function.
type = "link": Compute posterior draws of the linear predictor using the brms::posterior_linpred function.
type = "prediction": Compute posterior draws of the posterior predictive distribution using the brms::posterior_predict function.
The predictions
and marginaleffects functions can also pass additional arguments to the brms prediction functions via the ... ellipsis. For example, if mod is a mixed-effects model, then this command will compute 10 draws from the posterior predictive distribution, while ignoring all group-level effects:
predictions(mod, type = "prediction", ndraws = 10, re_formula = NA)
See the documentation of brms::posterior_epred, brms::posterior_linpred, and brms::posterior_predict for a list of available arguments.
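The three type options can be illustrated without fitting a model. The toy sketch below is not how brms computes these quantities internally. It assumes a logit model with a single covariate and uses simulated, hypothetical draws of the coefficients in place of a real posterior. For a Bernoulli outcome, it shows the rough analogue of each type of draw.

```r
# Toy illustration (not brms): simulated "posterior draws" for a logit model
set.seed(1)
n_draws <- 4000
b0 <- rnorm(n_draws, mean = -0.5, sd = 0.1)  # hypothetical intercept draws
b1 <- rnorm(n_draws, mean = 1.2, sd = 0.1)   # hypothetical slope draws
x <- 0.8                                     # a single hypothetical covariate value

# Analogue of type = "link" (posterior_linpred): draws of the linear predictor
link_draws <- b0 + b1 * x
# Analogue of type = "response" (posterior_epred): inverse link applied to each draw
response_draws <- plogis(link_draws)
# Analogue of type = "prediction" (posterior_predict): one simulated 0/1 outcome per draw
prediction_draws <- rbinom(n_draws, size = 1, prob = response_draws)

summary(link_draws)
summary(response_draws)
table(prediction_draws)
```

The response draws are probabilities. The prediction draws add Bernoulli sampling noise on top of them, so each one is 0 or 1.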
Load libraries and download data on passengers of the Titanic from the Rdatasets archive:
library(marginaleffects)
library(brms)
library(ggplot2)
library(ggdist)
library(magrittr)
dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/carData/TitanicSurvival.csv")
dat$survived <- ifelse(dat$survived == "yes", 1, 0)
dat$woman <- ifelse(dat$sex == "female", 1, 0)
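Fit a logit model with a multiplicative interaction:
mod <- brm(survived ~ woman * age + passengerClass,
family = bernoulli(link = "logit"),
data = dat)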
We can compute adjusted predicted values of the outcome variable (i.e., probability of survival aboard the Titanic) using the predictions function. By default, this function calculates predictions for each row of the dataset:
pred <- predictions(mod)
head(pred)
#> rowid type predicted survived woman age passengerClass conf.low conf.high
#> 1 1 response 0.9363293 1 1 29.0000 1st 0.9098216 0.9611878
#> 2 2 response 0.8473119 1 0 0.9167 1st 0.7602898 0.9255607
#> 3 3 response 0.9421204 0 1 2.0000 1st 0.8972346 0.9729983
#> 4 4 response 0.5129703 0 0 30.0000 1st 0.4317595 0.5960417
#> 5 5 response 0.9370423 0 1 25.0000 1st 0.9088370 0.9609094
#> 6 6 response 0.2731266 1 0 48.0000 1st 0.1997047 0.3484897
To visualize the relationship between the outcome and one of the regressors, we can plot conditional adjusted predictions with the plot_cap function.
Compute adjusted predictions for some user-specified values of the regressors, using the newdata argument and the datagrid function:
pred <- predictions(mod,
newdata = datagrid(woman = 0:1,
passengerClass = c("1st", "2nd", "3rd")))
pred
#> rowid type predicted age woman passengerClass conf.low conf.high
#> 1 1 response 0.51472830 29.88113 0 1st 0.43368752 0.5986352
#> 2 2 response 0.93614415 29.88113 1 1st 0.90930611 0.9605184
#> 3 3 response 0.20271297 29.88113 0 2nd 0.15011114 0.2585327
#> 4 4 response 0.77885268 29.88113 1 2nd 0.71264810 0.8359897
#> 5 5 response 0.08758891 29.88113 0 3rd 0.06560364 0.1148385
#> 6 6 response 0.57199879 29.88113 1 3rd 0.49642230 0.6461308
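In the output above, age is fixed at 29.88113, presumably the sample mean of age. This is because datagrid() holds every regressor that is not listed explicitly at its mean (numeric variables) or mode (categorical variables). The base-R sketch below mimics that rule on a small made-up data frame. grid_sketch() is a hypothetical helper written only for illustration. It is not part of marginaleffects and ignores details that datagrid() handles.

```r
# Hypothetical helper (not part of marginaleffects): user-specified variables take
# the supplied values; every other column is held at its mean (numeric) or mode.
grid_sketch <- function(data, ...) {
  fixed <- list(...)
  others <- data[setdiff(names(data), names(fixed))]
  typical <- lapply(others, function(col) {
    if (is.numeric(col)) {
      mean(col, na.rm = TRUE)
    } else {
      counts <- table(col)
      value <- names(counts)[which.max(counts)]  # most frequent value
      if (is.logical(col)) as.logical(value) else value
    }
  })
  # Cartesian product of the typical values and the user-specified values
  expand.grid(c(typical, fixed), stringsAsFactors = FALSE)
}

toy <- data.frame(
  age = c(20, 30, 40, 50),
  woman = c(0, 1, 1, 0),
  passengerClass = c("1st", "3rd", "3rd", "2nd")
)
grid_sketch(toy, woman = 0:1, passengerClass = c("1st", "2nd", "3rd"))
grid_sketch(toy, woman = 1)
```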
The posteriordraws function extracts draws from the posterior distribution of the model and returns a data frame with drawid and draw columns.
pred <- posteriordraws(pred)
head(pred)
#> rowid type drawid draw age woman passengerClass
#> 1 1 response 1 0.5163038 29.88113 0 1st
#> 2 1 response 2 0.5340073 29.88113 0 1st
#> 3 1 response 3 0.5107874 29.88113 0 1st
#> 4 1 response 4 0.5086535 29.88113 0 1st
#> 5 1 response 5 0.5982841 29.88113 0 1st
#> 6 1 response 6 0.5252051 29.88113 0 1st
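Conceptually, this long data frame stacks a matrix of draws, with one row per draw and one column per row of the grid. The sketch below uses simulated draws rather than a fitted model. It shows how such a data frame can be built and how each grid row can be summarized. The summary used here (posterior median plus a 95% equal-tailed interval) is an assumption of this sketch. See the marginaleffects documentation for how predicted, conf.low, and conf.high are actually computed.

```r
# Toy sketch with simulated draws (rows = draws, columns = grid rows)
set.seed(2)
n_draws <- 1000
draw_matrix <- cbind(
  rnorm(n_draws, mean = 0.5, sd = 0.05),  # hypothetical draws for grid row 1
  rnorm(n_draws, mean = 0.9, sd = 0.02)   # hypothetical draws for grid row 2
)

# Long format: as.vector() is column-major, so all draws of row 1 come first
long <- data.frame(
  rowid  = rep(seq_len(ncol(draw_matrix)), each = n_draws),
  drawid = rep(seq_len(n_draws), times = ncol(draw_matrix)),
  draw   = as.vector(draw_matrix)
)

# Assumed summary: posterior median and 95% equal-tailed interval per grid row
summarize_draws <- function(x) {
  c(estimate = median(x), quantile(x, c(0.025, 0.975)))
}
summary_table <- t(sapply(split(long$draw, long$rowid), summarize_draws))
summary_table
```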
This “long” format makes it easy to plot results:
ggplot(pred, aes(x = draw, fill = factor(woman))) +
geom_density() +
facet_grid(~ passengerClass, labeller = label_both) +
labs(x = "Predicted probability of survival", y = "", fill = "Woman")
Use marginaleffects() to compute marginal effects (slopes of the regression equation) for each row of the dataset, and use summary() to compute “Average Marginal Effects”, that is, the average of all observation-level marginal effects:
mfx <- marginaleffects(mod)
summary(mfx)
#> Average marginal effects
#> Term Contrast Effect 2.5 % 97.5 %
#> 1 woman dY/dX 0.366002 0.338048 0.393182
#> 2 age dY/dX -0.005239 -0.007038 -0.003495
#> 3 passengerClass 2nd - 1st -0.236254 -0.306317 -0.166685
#> 4 passengerClass 3rd - 1st -0.387400 -0.454317 -0.321181
#>
#> Model type: brmsfit
#> Prediction type: response
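One way to compute an average marginal effect from posterior draws is to work draw by draw. For each draw, compute the slope for every observation and average those slopes. Then summarize the resulting distribution across draws. The sketch below illustrates this for a toy logit model with a single regressor. It uses the closed-form derivative of the logistic function, dP/dx = b1 * p * (1 - p). The data and coefficient draws are simulated and hypothetical, and the code is not a reimplementation of marginaleffects.

```r
# Toy sketch: average marginal effect of x in a logit model, per posterior draw
set.seed(3)
n_obs <- 200
n_draws <- 500
x <- rnorm(n_obs)                           # hypothetical regressor values
b0 <- rnorm(n_draws, mean = 0.2, sd = 0.05) # hypothetical intercept draws
b1 <- rnorm(n_draws, mean = 0.8, sd = 0.05) # hypothetical slope draws

# For each draw: observation-level slopes b1 * p * (1 - p), averaged over observations
ame_draws <- vapply(seq_len(n_draws), function(s) {
  p <- plogis(b0[s] + b1[s] * x)
  mean(b1[s] * p * (1 - p))
}, numeric(1))

# Summarize the posterior distribution of the average marginal effect
c(estimate = median(ame_draws), quantile(ame_draws, c(0.025, 0.975)))
```

Because p * (1 - p) never exceeds 0.25, each average slope is at most a quarter of the corresponding coefficient draw.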
Compute marginal effects with some regressors fixed at user-specified values, and other regressors held at their means:
marginaleffects(mod,
newdata = datagrid(woman = 1,
passengerClass = "1st"))
#> rowid type term contrast dydx conf.low conf.high age woman
#> 1 1 response woman dY/dX 0.1564433778 0.109946544 0.206274493 29.88113 1
#> 2 1 response age dY/dX -0.0002116438 -0.001443799 0.000902002 29.88113 1
#> 3 1 response passengerClass 2nd - 1st -0.1569909361 -0.216442214 -0.100785866 29.88113 1
#> 4 1 response passengerClass 3rd - 1st -0.3638107510 -0.436897055 -0.290797268 29.88113 1
#> passengerClass
#> 1 1st
#> 2 1st
#> 3 1st
#> 4 1st
Compute and plot conditional marginal effects:
The posteriordraws function produces a dataset with drawid and draw columns:
draws <- posteriordraws(mfx)
dim(draws)
#> [1] 16736000 10
head(draws)
#> rowid type term contrast drawid draw survived woman age passengerClass
#> 1 1 response age dY/dX 1 0.0000882587 1 1 29 1st
#> 2 1 response age dY/dX 2 0.0004841800 1 1 29 1st
#> 3 1 response age dY/dX 3 0.0002124318 1 1 29 1st
#> 4 1 response age dY/dX 4 -0.0007529172 1 1 29 1st
#> 5 1 response age dY/dX 5 -0.0007341808 1 1 29 1st
#> 6 1 response age dY/dX 6 0.0003494105 1 1 29 1st
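The number of rows reported by dim(draws) is the product of three quantities: the number of observations, the number of terms, and the number of posterior draws. The arithmetic below relies on two assumptions that the output does not state. First, the model was fit on the passengers with non-missing age. Second, brms used its default of 4 chains with 1,000 post-warmup iterations each. Under those assumptions the product matches the 16,736,000 rows above.

```r
# Assumptions (not shown in the output above):
# - the model is fit on the passengers with non-missing age
#   (263 of the 1,309 ages in TitanicSurvival are missing);
# - brms used its default of 4 chains x 1,000 post-warmup iterations.
n_rows  <- 1309 - 263  # observations used by the model
n_terms <- 4           # woman, age, and two passengerClass contrasts
n_draws <- 4 * 1000    # posterior draws
n_rows * n_terms * n_draws
#> [1] 16736000
```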
We can use this dataset to plot our results. For example, to plot the posterior density of the marginal effect of age when the woman variable is equal to 0 or 1:
mfx <- marginaleffects(mod,
variables = "age",
newdata = datagrid(woman = 0:1)) |>
posteriordraws()
ggplot(mfx, aes(x = draw, fill = factor(woman))) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Marginal Effect of Age on Survival",
y = "Posterior density",
fill = "Woman")
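The two densities differ because the model includes a woman * age interaction. On the probability scale, the slope of age is (b_age + b_woman:age * woman) * p * (1 - p), so it depends on both the interaction coefficient and the predicted probability p. The sketch below evaluates this expression for woman = 0 and woman = 1. It uses hypothetical coefficient draws, not estimates from the Titanic model, at an assumed age of 30, and it omits passengerClass for brevity.

```r
# Toy sketch with hypothetical coefficient draws (not estimates from the Titanic model)
set.seed(4)
n_draws <- 1000
b0   <- rnorm(n_draws, mean = 0.5, sd = 0.1)      # intercept
b_w  <- rnorm(n_draws, mean = 2.0, sd = 0.1)      # woman
b_a  <- rnorm(n_draws, mean = -0.03, sd = 0.005)  # age
b_wa <- rnorm(n_draws, mean = 0.03, sd = 0.005)   # woman:age interaction
age  <- 30                                        # hypothetical age at which slopes are evaluated

slope_age <- function(woman) {
  # Predicted probability of survival, one value per draw
  p <- plogis(b0 + b_w * woman + (b_a + b_wa * woman) * age)
  # Chain rule: dP/d(age) = (b_a + b_wa * woman) * p * (1 - p)
  (b_a + b_wa * woman) * p * (1 - p)
}

slopes <- data.frame(
  woman = rep(0:1, each = n_draws),
  draw  = c(slope_age(0), slope_age(1))
)
aggregate(draw ~ woman, data = slopes, FUN = median)
```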
This section replicates some of the analyses of a random effects model published in Andrew Heiss’ blog post: “A guide to correctly calculating posterior predictions and average marginal effects with multilevel Bayesian models.” The objective is mainly to illustrate the use of marginaleffects. Please refer to the original post for a detailed discussion of the quantities computed below.
Load libraries and download data:
library(brms)
library(ggdist)
library(patchwork)
library(marginaleffects)
vdem_2015 <- read.csv("https://github.com/vincentarelbundock/marginaleffects/raw/main/data-raw/vdem_2015.csv")
head(vdem_2015)
#> country_name country_text_id year region media_index party_autonomy_ord
#> 1 Mexico MEX 2015 Latin America and the Caribbean 0.837 3
#> 2 Suriname SUR 2015 Latin America and the Caribbean 0.883 4
#> 3 Sweden SWE 2015 Western Europe and North America 0.956 4
#> 4 Switzerland CHE 2015 Western Europe and North America 0.939 4
#> 5 Ghana GHA 2015 Sub-Saharan Africa 0.858 4
#> 6 South Africa ZAF 2015 Sub-Saharan Africa 0.898 4
#> polyarchy civil_liberties party_autonomy
#> 1 0.631 0.704 TRUE
#> 2 0.777 0.887 TRUE
#> 3 0.915 0.968 TRUE
#> 4 0.901 0.960 TRUE
#> 5 0.724 0.921 TRUE
#> 6 0.752 0.869 TRUE
Fit a basic model:
mod <- brm(
bf(media_index ~ party_autonomy + civil_liberties + (1 | region),
phi ~ (1 | region)),
data = vdem_2015,
family = Beta(),
control = list(adapt_delta = 0.9))
To compute posterior predictions for specific values of the regressors, we use the newdata argument and the datagrid function. We also use the type argument to compute two types of predictions: one that accounts for residual (observation-level) variance (type = "prediction") and one that ignores it (type = "response").
nd <- datagrid(model = mod,
party_autonomy = c(TRUE, FALSE),
civil_liberties = .5,
region = "Middle East and North Africa")
p1 <- predictions(mod, type = "response", newdata = nd) |>
posteriordraws()
p2 <- predictions(mod, type = "prediction", newdata = nd) |>
posteriordraws()
pred <- rbind(p1, p2)
Plot the posterior draws:
ggplot(pred, aes(x = draw, fill = party_autonomy)) +
stat_halfeye(alpha = .5) +
facet_wrap(~ type) +
labs(x = "Media index (predicted)",
y = "Posterior density",
fill = "Party autonomy")
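The prediction panel is wider than the response panel because posterior predictive draws add observation-level Beta noise on top of the uncertainty in the expected value. The sketch below relies on the mean/precision parameterization that brms uses for its Beta family, y ~ Beta(mu * phi, (1 - mu) * phi). The draws of mu and phi are simulated and hypothetical, and the region effects are ignored.

```r
# Toy sketch: Beta family with mean mu and precision phi
set.seed(5)
n_draws <- 4000
mu  <- plogis(rnorm(n_draws, mean = 0.3, sd = 0.1))  # hypothetical draws of the mean (logit link)
phi <- rgamma(n_draws, shape = 50, rate = 2)         # hypothetical draws of the precision (mean 25)

# Analogue of type = "response": the expected value itself
response_draws <- mu
# Analogue of type = "prediction": one simulated outcome per draw
prediction_draws <- rbeta(n_draws, shape1 = mu * phi, shape2 = (1 - mu) * phi)

c(sd_response = sd(response_draws), sd_prediction = sd(prediction_draws))
```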
As noted in the Marginal Effects vignette, there is one distinct marginal effect for each combination of regressor values. Here, we consider only one combination of regressor values, where region is “Middle East and North Africa” and civil_liberties is 0.5. Then, we summarize the posterior distribution of marginal effects:
mfx <- marginaleffects(mod,
newdata = datagrid(civil_liberties = .5,
region = "Middle East and North Africa"))
mfx
#> rowid type term contrast dydx conf.low conf.high party_autonomy
#> 1 1 response party_autonomy TRUE - FALSE 0.2514419 0.1727092 0.3378497 TRUE
#> 2 1 response civil_liberties dY/dX 0.8161557 0.6245150 1.0072559 TRUE
#> civil_liberties region
#> 1 0.5 Middle East and North Africa
#> 2 0.5 Middle East and North Africa
Use the posteriordraws() function to extract draws from the posterior distribution of marginal effects, and plot them:
mfx <- posteriordraws(mfx)
ggplot(mfx, aes(x = draw, y = term)) +
stat_halfeye() +
labs(x = "Marginal effect", y = "")
Plot adjusted predictions, conditional on a regressor (civil liberties):
pred <- predictions(mod,
newdata = datagrid(party_autonomy = FALSE,
region = "Middle East and North Africa",
civil_liberties = seq(0, 1, by = 0.05))) |>
posteriordraws()
ggplot(pred, aes(x = civil_liberties, y = draw)) +
stat_lineribbon() +
scale_fill_brewer(palette = "Reds") +
labs(x = "Civil liberties",
y = "Media index (predicted)",
fill = "")
The slope of this line for different values of civil liberties can be obtained with:
mfx <- marginaleffects(mod,
newdata = datagrid(civil_liberties = c(.2, .5, .8),
party_autonomy = FALSE,
region = "Middle East and North Africa"),
variables = "civil_liberties")
mfx
#> rowid type term dydx conf.low conf.high civil_liberties party_autonomy
#> 1 1 response civil_liberties 0.4889313 0.3655136 0.6382760 0.2 FALSE
#> 2 2 response civil_liberties 0.8073257 0.6130583 1.0031587 0.5 FALSE
#> 3 3 response civil_liberties 0.8065424 0.6712491 0.9274168 0.8 FALSE
#> region
#> 1 Middle East and North Africa
#> 2 Middle East and North Africa
#> 3 Middle East and North Africa
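The model has a single coefficient for civil_liberties, yet the slope differs across its values. The reason is that the logit link makes the expected value a nonlinear function of the linear predictor. The sketch below uses hypothetical coefficient values and ignores the region effect. It computes the slope of mu = plogis(a + b * x) at the same three values with a centered finite difference and compares it with the closed-form derivative b * mu * (1 - mu).

```r
# Toy sketch: slope of the expected value under a logit link
a <- -1.5  # hypothetical intercept on the logit scale
b <- 3.5   # hypothetical coefficient of civil liberties on the logit scale
mu <- function(x) plogis(a + b * x)

x <- c(0.2, 0.5, 0.8)
h <- 1e-4
slope_numeric  <- (mu(x + h) - mu(x - h)) / (2 * h)  # centered finite difference
slope_analytic <- b * mu(x) * (1 - mu(x))            # closed-form derivative
cbind(x, slope_numeric, slope_analytic)
```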
And plotted:
mfx <- posteriordraws(mfx)
ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Marginal effect of Civil Liberties on Media Index",
y = "Posterior density",
fill = "Civil liberties")
The marginaleffects function can use the ellipsis (...) to push any argument forward to the underlying brms prediction function (posterior_epred for the default type = "response"). This can alter the types of predictions returned. For example, the re_formula = NA argument, documented in the posterior_predict.brmsfit method, will compute marginal effects without including any group-level effects:
mfx <- marginaleffects(mod,
newdata = datagrid(civil_liberties = c(.2, .5, .8),
party_autonomy = FALSE,
region = "Middle East and North Africa"),
variables = "civil_liberties",
re_formula = NA) |>
posteriordraws()
ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Marginal effect of Civil Liberties on Media Index",
y = "Posterior density",
fill = "Civil liberties")
pred <- predictions(mod,
re_formula = NA,
newdata = datagrid(party_autonomy = c(TRUE, FALSE))) |>
posteriordraws()
mfx <- marginaleffects(mod,
re_formula = NA,
variables = "party_autonomy") |>
posteriordraws()
plot1 <- ggplot(pred, aes(x = draw, fill = party_autonomy)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Media index (Predicted)",
y = "Posterior density",
fill = "Party autonomy")
plot2 <- ggplot(mfx, aes(x = draw)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Contrast: Party autonomy TRUE - FALSE",
y = "",
fill = "Party autonomy")
# combine plots using the `patchwork` package
plot1 + plot2
Predicted media index by region and level of civil liberties:
pred <- predictions(mod,
newdata = datagrid(region = vdem_2015$region,
party_autonomy = FALSE,
civil_liberties = seq(0, 1, length.out = 100))) |>
posteriordraws()
ggplot(pred, aes(x = civil_liberties, y = draw)) +
stat_lineribbon() +
scale_fill_brewer(palette = "Reds") +
facet_wrap(~ region) +
labs(x = "Civil liberties",
y = "Media index (predicted)",
fill = "")
Posterior distribution of the predicted media index by region, for low (0.2) and high (0.8) levels of civil liberties:
pred <- predictions(mod,
newdata = datagrid(region = vdem_2015$region,
civil_liberties = c(.2, .8),
party_autonomy = FALSE)) |>
posteriordraws()
ggplot(pred, aes(x = draw, fill = factor(civil_liberties))) +
stat_halfeye(slab_alpha = .5) +
facet_wrap(~ region) +
labs(x = "Media index (predicted)",
y = "Posterior density",
fill = "Civil liberties")
Predicted media index by region and party autonomy:
pred <- predictions(mod,
newdata = datagrid(region = vdem_2015$region,
party_autonomy = c(TRUE, FALSE),
civil_liberties = .5)) |>
posteriordraws()
ggplot(pred, aes(x = draw, y = region, fill = party_autonomy)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Media index (predicted)",
y = "",
fill = "Party autonomy")
TRUE/FALSE contrasts (marginal effects) of party autonomy by region:
mfx <- marginaleffects(mod,
variables = "party_autonomy",
newdata = datagrid(region = vdem_2015$region,
civil_liberties = .5)) |>
posteriordraws()
ggplot(mfx, aes(x = draw, y = region)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Contrast: Party autonomy TRUE - FALSE",
y = "")
We can also obtain predictions or marginal effects for a hypothetical group instead of one of the observed regions. To achieve this, we create a dataset with a new, unobserved level ("New Region") in the region column. Then we call the marginaleffects or predictions functions with the allow_new_levels = TRUE argument. This argument is pushed through via the ellipsis (...) to the posterior_epred function of the brms package:
dat <- data.frame(civil_liberties = .5,
party_autonomy = FALSE,
region = "New Region")
mfx <- marginaleffects(
mod,
variables = "party_autonomy",
allow_new_levels = TRUE,
newdata = dat)
draws <- posteriordraws(mfx)
ggplot(draws, aes(x = draw)) +
stat_halfeye() +
labs(x = "Marginal effect of party autonomy in a generic world region", y = "")
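With allow_new_levels = TRUE, brms has to choose a group-level effect for the unobserved region. Our reading of the brms documentation for the sample_new_levels argument is that its default, "uncertainty", handles each posterior draw by borrowing the effect of a randomly chosen existing level. The sketch below mimics that rule on a simulated, hypothetical matrix of region intercepts. It is an illustration under that assumption, not the brms code. It shows why the new-region distribution is wider than that of any single observed region.

```r
# Toy sketch of the assumed sample_new_levels = "uncertainty" rule
set.seed(6)
n_draws  <- 1000
n_levels <- 6                                     # hypothetical number of observed regions
level_means <- seq(-1, 1, length.out = n_levels)  # hypothetical region intercepts
# Hypothetical posterior draws of region intercepts (rows = draws, columns = regions)
region_effects <- matrix(
  rnorm(n_draws * n_levels, mean = rep(level_means, each = n_draws), sd = 0.1),
  nrow = n_draws
)

# For each draw, take the effect of one randomly chosen existing region
chosen <- sample.int(n_levels, n_draws, replace = TRUE)
new_level_effect <- region_effects[cbind(seq_len(n_draws), chosen)]

# The new level inherits the spread across regions, not only within one region
c(sd_one_region = sd(region_effects[, 1]), sd_new_level = sd(new_level_effect))
```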
Fit a model with a categorical outcome (heating system choice in California houses) and a logit link:
dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/Heating.csv"
dat <- read.csv(dat)
mod <- brm(depvar ~ ic.gc + oc.gc,
data = dat,
family = categorical(link = "logit"))
Compute predicted probabilities for each level of the outcome variable:
pred <- predictions(mod)
head(pred)
#> rowid type group predicted depvar ic.gc oc.gc conf.low conf.high
#> 1 1 response ec 0.06650248 gc 866.00 199.69 0.04444724 0.09105698
#> 2 2 response ec 0.07698538 gc 727.93 168.66 0.05963308 0.09663204
#> 3 3 response ec 0.10419361 gc 599.48 165.58 0.06036268 0.15931857
#> 4 4 response ec 0.06352814 er 835.17 180.88 0.04593506 0.08205828
#> 5 5 response ec 0.07480561 er 755.59 174.91 0.05786132 0.09305666
#> 6 6 response ec 0.07151562 gc 666.11 135.67 0.04472428 0.10273280
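In a categorical logit model, the predicted probabilities of all outcome levels sum to one for each observation. The sketch below computes these probabilities with a softmax over the linear predictors. It assumes that the first level (ec) is the reference category, with its linear predictor fixed at 0. The linear predictors are random, hypothetical values, not output from the heating model.

```r
# Toy sketch: categorical logit probabilities with the first category as reference
softmax_ref <- function(eta) {
  # eta: matrix of linear predictors for the non-reference categories (rows = observations);
  # the reference category gets a linear predictor of 0
  expeta <- exp(cbind(0, eta))
  expeta / rowSums(expeta)  # divide each row by its own total
}

set.seed(7)
eta <- matrix(rnorm(3 * 4), nrow = 3)  # 3 hypothetical houses, 4 non-reference categories
probs <- softmax_ref(eta)
colnames(probs) <- c("ec", "er", "gc", "gr", "hp")
round(probs, 3)
rowSums(probs)
```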
Extract posterior draws and plot them:
draws <- posteriordraws(pred)
ggplot(draws, aes(x = draw, fill = group)) +
geom_density(alpha = .2, color = "white") +
labs(x = "Predicted probability",
y = "Density",
fill = "Heating system")
Compute marginal effects of the ic.gc and oc.gc regressors on the predicted probability of each level of the outcome variable (heating system), and use summary() to average them:
mfx <- marginaleffects(mod)
summary(mfx)
#> Average marginal effects
#> Group Term Effect 2.5 % 97.5 %
#> 1 ec ic.gc -1.826e-04 -4.093e-04 3.378e-05
#> 2 ec oc.gc 4.875e-04 -4.283e-04 1.476e-03
#> 3 er ic.gc 1.673e-05 -2.130e-04 2.408e-04
#> 4 er oc.gc -1.025e-03 -2.005e-03 1.385e-05
#> 5 gc ic.gc 1.235e-05 -3.674e-04 3.875e-04
#> 6 gc oc.gc 1.057e-03 -5.282e-04 2.791e-03
#> 7 gr ic.gc 4.252e-05 -2.575e-04 3.130e-04
#> 8 gr oc.gc 7.841e-05 -1.191e-03 1.299e-03
#> 9 hp ic.gc 1.110e-04 -6.588e-05 2.983e-04
#> 10 hp oc.gc -5.973e-04 -1.420e-03 2.241e-04
#>
#> Model type: brmsfit
#> Prediction type: response
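The predicted probabilities of the five heating systems sum to one for every house, so their derivatives with respect to any regressor sum to zero. As a result, the average marginal effects of each term should add up to roughly zero across groups. The identity holds only approximately, because the printed values are posterior summaries. The sketch below first sums the values printed above. It then checks the identity on a toy softmax model whose intercepts and coefficients are hypothetical.

```r
# 1) Average marginal effects copied from the summary above (order: ec, er, gc, gr, hp)
ame_ic <- c(-1.826e-04, 1.673e-05, 1.235e-05, 4.252e-05, 1.110e-04)
ame_oc <- c(4.875e-04, -1.025e-03, 1.057e-03, 7.841e-05, -5.973e-04)
c(ic.gc = sum(ame_ic), oc.gc = sum(ame_oc))

# 2) Toy check: finite-difference slopes of softmax probabilities sum to zero
softmax_ref <- function(eta) {
  e <- exp(cbind(0, eta))  # first category is the reference
  e / rowSums(e)
}
alpha <- c(0.5, -0.2, 1, 0.1)             # hypothetical intercepts
beta  <- c(0.002, -0.001, 0.003, 0.0005)  # hypothetical coefficients of one regressor
x <- 800
h <- 0.01
p_hi <- softmax_ref(matrix(alpha + beta * (x + h), nrow = 1))
p_lo <- softmax_ref(matrix(alpha + beta * (x - h), nrow = 1))
slopes <- (p_hi - p_lo) / (2 * h)  # one slope per category
sum(slopes)
```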