predictions(mod, type = "prediction", ndraws = 10, re_formula = NA)
15 Bayes
The marginaleffects package offers convenience functions to compute and display predictions, contrasts, and marginal effects from Bayesian models estimated by the brms package. To compute these quantities, marginaleffects relies on workhorse functions from the brms package to draw from the posterior distribution. The type of draws is controlled by the type argument of the predictions or slopes functions:
- type = "response": Compute posterior draws of the expected value using the brms::posterior_epred function.
- type = "link": Compute posterior draws of the linear predictor using the brms::posterior_linpred function.
- type = "prediction": Compute posterior draws of the posterior predictive distribution using the brms::posterior_predict function.
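To make the three scales concrete, here is a minimal base R sketch for a logistic model. It does not call brms: the coefficient draws (beta_draws) and the observation (x_new) are simulated, hypothetical values that only illustrate how the linear predictor, the expected value, and the posterior predictive draws relate to one another.

```r
set.seed(1)

# Hypothetical posterior draws of an intercept and one slope (not from a real fit)
n_draws <- 1000
beta_draws <- cbind(
  intercept = rnorm(n_draws, mean = -1.0, sd = 0.2),
  slope     = rnorm(n_draws, mean =  0.8, sd = 0.1)
)

# One hypothetical observation: intercept term plus a regressor equal to 1.5
x_new <- c(1, 1.5)

# "link": draws of the linear predictor
draws_link <- as.vector(beta_draws %*% x_new)

# "response": draws of the expected value (here, a probability)
draws_response <- plogis(draws_link)

# "prediction": draws from the posterior predictive distribution (0/1 outcomes)
draws_prediction <- rbinom(n_draws, size = 1, prob = draws_response)

summary(draws_link)
summary(draws_response)
table(draws_prediction)
```

The "prediction" draws are noisier than the "response" draws because they add observation-level sampling variability on top of the uncertainty in the coefficients.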
The predictions and slopes functions can also pass additional arguments to the brms prediction functions via the ... ellipsis. For example, if mod is a mixed-effects model, then the command predictions(mod, type = "prediction", ndraws = 10, re_formula = NA) will compute 10 draws from the posterior predictive distribution, while ignoring all group-level effects.
See the brms
documentation for a list of available arguments:
?brms::posterior_epred
?brms::posterior_linpred
?brms::posterior_predict
15.1 Logistic regression with multiplicative interactions
Load libraries and download data on passengers of the Titanic from the Rdatasets archive:
Fit a logit model with a multiplicative interaction:
15.1.1 Adjusted predictions
We can compute adjusted predicted values of the outcome variable (i.e., probability of survival aboard the Titanic) using the predictions
function. By default, this function calculates predictions for each row of the dataset:
predictions(mod)
To visualize the relationship between the outcome and one of the regressors, we can plot conditional adjusted predictions with the plot_predictions
function:
plot_predictions(mod, condition = "age")
Compute adjusted predictions for some user-specified values of the regressors, using the newdata
argument and the datagrid
function:
pred <- predictions(mod,
newdata = datagrid(woman = 0:1,
passengerClass = c("1st", "2nd", "3rd")))
pred
The get_draws
function samples from the posterior distribution of the model, and produces a data frame with drawid
and draw
columns.
This “long” format makes it easy to plot results:
ggplot(pred, aes(x = draw, fill = factor(woman))) +
geom_density() +
facet_grid(~ passengerClass, labeller = label_both) +
labs(x = "Predicted probability of survival", y = "", fill = "Woman")
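For readers who want to picture the “long” format, the following base R sketch stacks a small matrix of draws into a data frame with drawid and draw columns. The matrix is simulated, and the rowid column is our own illustrative label for the prediction each draw belongs to; this is a sketch of the layout, not the internal implementation of get_draws.

```r
set.seed(2)

# Hypothetical matrix of posterior draws: 4 draws (rows) for 3 predictions (columns)
draw_matrix <- matrix(runif(12), nrow = 4, ncol = 3)

# Stack into long format: one row per (draw, prediction) pair.
# as.vector() reads the matrix column by column, so the draw index cycles
# fastest and the prediction index changes every nrow(draw_matrix) rows.
long <- data.frame(
  drawid = rep(seq_len(nrow(draw_matrix)), times = ncol(draw_matrix)),
  rowid  = rep(seq_len(ncol(draw_matrix)), each = nrow(draw_matrix)),
  draw   = as.vector(draw_matrix)
)
head(long)
```

With one row per draw, plotting functions such as geom_density or stat_halfeye can map the draw column directly to an axis.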
15.1.2 Marginal effects
Use slopes() to compute marginal effects (slopes of the regression equation) for each row of the dataset, and use avg_slopes() to compute “Average Marginal Effects”, that is, the average of all observation-level marginal effects:
mfx <- slopes(mod)
mfx
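The idea of averaging observation-level slopes can be sketched in base R. The example below assumes a simple logit model with one regressor and uses simulated, hypothetical coefficient draws (a and b) instead of a fitted brms model. It computes a finite-difference slope for every observation in every draw, averages over observations within each draw, and then summarizes across draws.

```r
set.seed(3)

# Hypothetical data and posterior draws for a logit model: y ~ x
x <- rnorm(50)
n_draws <- 500
a <- rnorm(n_draws, -0.5, 0.1)   # intercept draws
b <- rnorm(n_draws,  1.0, 0.1)   # slope draws

# Finite-difference slope of the predicted probability, per draw and observation
eps <- 1e-4
slope_one_draw <- function(a_d, b_d) {
  (plogis(a_d + b_d * (x + eps)) - plogis(a_d + b_d * x)) / eps
}
unit_slopes <- mapply(slope_one_draw, a, b)   # 50 observations x 500 draws

# Average over observations within each draw, then summarize across draws
ame_draws <- colMeans(unit_slopes)
median(ame_draws)
quantile(ame_draws, probs = c(0.025, 0.975))
```

Averaging within each draw first yields a full posterior distribution for the average marginal effect, which can then be summarized by a median and an interval.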
Compute marginal effects with some regressors fixed at user-specified values, and other regressors held at their means:
Compute and plot conditional marginal effects:
plot_slopes(mod, variables = "woman", condition = "age")
The get_draws function produces a dataset with drawid and draw columns:
We can use this dataset to plot our results. For example, to plot the posterior density of the marginal effect of age
when the woman
variable is equal to 0 or 1:
15.2 Random effects model
This section replicates some of the analyses of a random effects model published in Andrew Heiss’ blog post: “A guide to correctly calculating posterior predictions and average marginal effects with multilevel Bayesian models.” The objective is mainly to illustrate the use of marginaleffects
. Please refer to the original post for a detailed discussion of the quantities computed below.
Load libraries and download data:
Fit a basic model:
15.2.1 Posterior predictions
To compute posterior predictions for specific values of the regressors, we use the newdata argument and the datagrid function. We also use the type argument to compute two types of predictions: accounting for residual (observation-level) variance (prediction) or ignoring it (response).
nd = datagrid(model = mod,
party_autonomy = c(TRUE, FALSE),
civil_liberties = .5,
region = "Middle East and North Africa")
p1 <- predictions(mod, type = "response", newdata = nd) |>
get_draws() |>
transform(type = "Response")
p2 <- predictions(mod, type = "prediction", newdata = nd) |>
get_draws() |>
transform(type = "Prediction")
pred <- rbind(p1, p2)
Extract posterior draws and plot them:
ggplot(pred, aes(x = draw, fill = party_autonomy)) +
stat_halfeye(alpha = .5) +
facet_wrap(~ type) +
labs(x = "Media index (predicted)",
y = "Posterior density",
fill = "Party autonomy")
15.2.2 Marginal effects and contrasts
As noted in the Marginal Effects vignette, there should be one distinct marginal effect for each combination of regressor values. Here, we consider only one combination of regressor values, where region
is “Middle East and North Africa”, and civil_liberties
is 0.5. Then, we calculate the mean of the posterior distribution of marginal effects:
Use get_draws() to extract draws from the posterior distribution of marginal effects, and plot them:
mfx <- get_draws(mfx)
ggplot(mfx, aes(x = draw, y = term)) +
stat_halfeye() +
labs(x = "Marginal effect", y = "")
Plot marginal effects, conditional on a regressor:
plot_slopes(mod,
variables = "civil_liberties",
condition = "party_autonomy")
15.2.3 Continuous predictors
pred <- predictions(mod,
newdata = datagrid(party_autonomy = FALSE,
region = "Middle East and North Africa",
civil_liberties = seq(0, 1, by = 0.05))) |>
get_draws()
ggplot(pred, aes(x = civil_liberties, y = draw)) +
stat_lineribbon() +
scale_fill_brewer(palette = "Reds") +
labs(x = "Civil liberties",
y = "Media index (predicted)",
fill = "")
The slope of this line for different values of civil liberties can be obtained with:
And plotted:
The slopes function can use the ellipsis (...) to push any argument forward to the posterior_predict function. This can alter the types of predictions returned. For example, the re_formula = NA argument of the posterior_predict.brmsfit method will compute marginal effects without including any group-level effects:
mfx <- slopes(
mod,
newdata = datagrid(
civil_liberties = c(.2, .5, .8),
party_autonomy = FALSE,
region = "Middle East and North Africa"),
variables = "civil_liberties",
re_formula = NA) |>
get_draws()
ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Marginal effect of Civil Liberties on Media Index",
y = "Posterior density",
fill = "Civil liberties")
15.2.4 Global grand mean
pred <- predictions(
mod,
re_formula = NA,
newdata = datagrid(party_autonomy = c(TRUE, FALSE))) |>
get_draws()
mfx <- slopes(
mod,
re_formula = NA,
variables = "party_autonomy") |>
get_draws()
plot1 <- ggplot(pred, aes(x = draw, fill = party_autonomy)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Media index (Predicted)",
y = "Posterior density",
fill = "Party autonomy")
plot2 <- ggplot(mfx, aes(x = draw)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Contrast: Party autonomy TRUE - FALSE",
y = "",
fill = "Party autonomy")
## combine plots using the `patchwork` package
plot1 + plot2
15.2.5 Region-specific predictions and contrasts
Predicted media index by region and level of civil liberties:
pred <- predictions(mod,
newdata = datagrid(region = vdem_2015$region,
party_autonomy = FALSE,
civil_liberties = seq(0, 1, length.out = 100))) |>
get_draws()
ggplot(pred, aes(x = civil_liberties, y = draw)) +
stat_lineribbon() +
scale_fill_brewer(palette = "Reds") +
facet_wrap(~ region) +
labs(x = "Civil liberties",
y = "Media index (predicted)",
fill = "")
Predicted media index by region and level of civil liberties:
pred <- predictions(mod,
newdata = datagrid(region = vdem_2015$region,
civil_liberties = c(.2, .8),
party_autonomy = FALSE)) |>
get_draws()
ggplot(pred, aes(x = draw, fill = factor(civil_liberties))) +
stat_halfeye(slab_alpha = .5) +
facet_wrap(~ region) +
labs(x = "Media index (predicted)",
y = "Posterior density",
fill = "Civil liberties")
Predicted media index by region and party autonomy:
pred <- predictions(mod,
newdata = datagrid(region = vdem_2015$region,
party_autonomy = c(TRUE, FALSE),
civil_liberties = .5)) |>
get_draws()
ggplot(pred, aes(x = draw, y = region , fill = party_autonomy)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Media index (predicted)",
y = "",
fill = "Party autonomy")
TRUE/FALSE contrasts (marginal effects) of party autonomy by region:
mfx <- slopes(
mod,
variables = "party_autonomy",
newdata = datagrid(
region = vdem_2015$region,
civil_liberties = .5)) |>
get_draws()
ggplot(mfx, aes(x = draw, y = region , fill = party_autonomy)) +
stat_halfeye(slab_alpha = .5) +
labs(x = "Media index (predicted)",
y = "",
fill = "Party autonomy")
15.2.6 Hypothetical groups
We can also obtain predictions or marginal effects for a hypothetical group instead of one of the observed regions. To achieve this, we create a dataset with a previously unseen level in the region column (here, "New Region"). Then we call the slopes or predictions functions with the allow_new_levels argument. This argument is pushed through via the ellipsis (...) to the posterior_epred function of the brms package:
dat <- data.frame(civil_liberties = .5,
party_autonomy = FALSE,
region = "New Region")
mfx <- slopes(
mod,
variables = "party_autonomy",
allow_new_levels = TRUE,
newdata = dat)
draws <- get_draws(mfx)
ggplot(draws, aes(x = draw)) +
stat_halfeye() +
labs(x = "Marginal effect of party autonomy in a generic world region", y = "")
15.2.7 Averaging, marginalizing, integrating random effects
Consider a logistic regression model with random effects:
dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/plm/EmplUK.csv")
dat$x <- as.numeric(dat$output > median(dat$output))
dat$y <- as.numeric(dat$emp > median(dat$emp))
mod <- brm(y ~ x + (1 | firm), data = dat, backend = "cmdstanr", family = "bernoulli")
We can compute adjusted predictions for a given value of x
and for each firm (random effects) as follows:
p <- predictions(mod, newdata = datagrid(x = 0, firm = unique))
head(p)
We can average/marginalize/integrate across random effects with the avg_predictions()
function or the by
argument:
avg_predictions(mod, newdata = datagrid(x = 0, firm = unique))
predictions(mod, newdata = datagrid(x = 0:1, firm = unique), by = "x")
We can also draw from the (assumed Gaussian) population distribution of random effects, by asking predictions()
to make predictions for new “levels” of the random effects. If we then take an average of predictions using avg_predictions()
or the by
argument, we will have “integrated out the random effects”, as described in the brmsmargins
package vignette. In the code below, we make predictions for 100 firm identifiers which were not in the original dataset. We also ask predictions()
to push forward the allow_new_levels
and sample_new_levels
arguments to the brms::posterior_epred
function:
predictions(
mod,
newdata = datagrid(x = 0:1, firm = -1:-100),
allow_new_levels = TRUE,
sample_new_levels = "gaussian",
by = "x")
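What “integrating out the random effects” means can be sketched without brms. In the base R example below, all posterior draws (alpha, beta, sd_u) are simulated, hypothetical values for a random-intercept logit model. For each draw, we sample intercepts for 100 new firms from a Gaussian distribution, compute predicted probabilities, and average them over firms.

```r
set.seed(4)

# Hypothetical posterior draws (not from a fitted model)
n_draws <- 200
alpha <- rnorm(n_draws, -0.3, 0.1)        # intercept
beta  <- rnorm(n_draws,  0.9, 0.1)        # coefficient on x
sd_u  <- abs(rnorm(n_draws, 1.2, 0.1))    # SD of the firm-level intercepts

n_new <- 100   # number of hypothetical new firms

# For one draw: sample new firm intercepts, predict, and average over firms
integrate_out <- function(d, x) {
  u <- rnorm(n_new, mean = 0, sd = sd_u[d])
  mean(plogis(alpha[d] + beta[d] * x + u))
}
p0 <- sapply(seq_len(n_draws), integrate_out, x = 0)
p1 <- sapply(seq_len(n_draws), integrate_out, x = 1)

# Posterior medians of the population-averaged probabilities, and their contrast
c(x0 = median(p0), x1 = median(p1), diff = median(p1 - p0))
```

Because the average is taken on the probability scale, the result generally differs from plugging a random effect of zero into the inverse link, which is what re_formula = NA would do.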
We can “integrate out” random effects in other marginaleffects functions too. For instance,
avg_comparisons(
mod,
newdata = datagrid(firm = -1:-100),
allow_new_levels = TRUE,
sample_new_levels = "gaussian")
This is nearly equivalent to the brmsmargins command output (with slight variations due to different random seeds):
library(brmsmargins)
bm <- brmsmargins(
k = 100,
object = mod,
at = data.frame(x = c(0, 1)),
CI = .95,
CIType = "ETI",
contrasts = cbind("AME x" = c(-1, 1)),
effects = "integrateoutRE")
bm$ContrastSummary |> data.frame()
See the alternative software vignette for more information on brmsmargins
.
15.3 Multinomial logit
Fit a model with categorical outcome (heating system choice in California houses) and logit link:
dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/Heating.csv"
dat <- read.csv(dat)
mod <- brm(depvar ~ ic.gc + oc.gc,
data = dat,
family = categorical(link = "logit"))
15.3.1 Adjusted predictions
Compute predicted probabilities for each level of the outcome variable:
pred <- predictions(mod)
head(pred)
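In a categorical model with a logit link, predicted probabilities come from applying the softmax function to one linear predictor per outcome level, with the reference level fixed at zero. The base R sketch below uses made-up linear predictor values for a single observation and a single posterior draw; the level labels and numbers are hypothetical and only illustrate the transformation.

```r
# Hypothetical linear predictors for one house and one posterior draw.
# The reference category has its linear predictor fixed at 0.
eta <- c(ec = 0, er = 0.4, gc = 2.1, gr = 0.6, hp = -0.3)

# Softmax: exponentiate and normalize so that probabilities sum to 1
softmax <- function(z) {
  z <- z - max(z)          # subtract the max for numerical stability
  exp(z) / sum(exp(z))
}
prob <- softmax(eta)
round(prob, 3)
```

This is why predictions for such a model return one row per observation and outcome level: each level gets its own probability, and the probabilities for an observation sum to one.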
Extract posterior draws and plot them:
draws <- get_draws(pred)
ggplot(draws, aes(x = draw, fill = group)) +
geom_density(alpha = .2, color = "white") +
labs(x = "Predicted probability",
y = "Density",
fill = "Heating system")
Use the plot_predictions function to plot conditional adjusted predictions for each level of the outcome variable depvar, conditional on the value of the oc.gc regressor:
plot_predictions(mod, condition = "oc.gc") +
facet_wrap(~ group) +
labs(y = "Predicted probability")
15.3.2 Marginal effects
avg_slopes(mod)
15.4 Hurdle models
This section replicates some analyses from yet another amazing blog post by Andrew Heiss.
To begin, we estimate a hurdle model in brms with random effects, using data from the gapminder package:
library(gapminder)
library(brms)
library(dplyr)
library(ggplot2)
library(ggdist)
library(cmdstanr)
library(patchwork)
library(marginaleffects)
set.seed(1024)
CHAINS <- 4
ITER <- 2000
WARMUP <- 1000
BAYES_SEED <- 1234
gapminder <- gapminder::gapminder |>
filter(continent != "Oceania") |>
# Make a bunch of GDP values 0
mutate(prob_zero = ifelse(lifeExp < 50, 0.3, 0.02),
will_be_zero = rbinom(n(), 1, prob = prob_zero),
gdpPercap = ifelse(will_be_zero, 0, gdpPercap)) |>
select(-prob_zero, -will_be_zero) |>
# Make a logged version of GDP per capita
mutate(log_gdpPercap = log1p(gdpPercap)) |>
mutate(is_zero = gdpPercap == 0)
mod <- brm(
bf(gdpPercap ~ lifeExp + year + (1 + lifeExp + year | continent),
hu ~ lifeExp),
data = gapminder,
backend = "cmdstanr",
family = hurdle_lognormal(),
cores = 2,
chains = CHAINS, iter = ITER, warmup = WARMUP, seed = BAYES_SEED,
silent = 2)
15.4.1 Adjusted predictions
Adjusted predictions for every observation in the original data:
predictions(mod) |> head()
Adjusted predictions for the hu
parameter:
predictions(mod, dpar = "hu") |> head()
Predictions on a different scale:
predictions(mod, type = "link", dpar = "hu") |> head()
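To see how these three quantities relate in a hurdle lognormal model, consider the following base R sketch. The values of mu, sigma, and the hurdle parameter are hypothetical numbers for a single observation and a single posterior draw. The sketch assumes the usual parameterization in which hu is the probability of a zero (modeled on the logit scale) and the non-zero part is lognormal.

```r
# Hypothetical values for one observation and one posterior draw
mu    <- 8.0    # mean of log(outcome) for the non-zero part
sigma <- 1.1    # SD of log(outcome) for the non-zero part
hu_link <- -2.5 # hurdle parameter on the logit ("link") scale

# dpar = "hu" on the default scale: probability of a zero
hu <- plogis(hu_link)

# Expected value of the outcome: P(non-zero) * mean of the lognormal part
expected <- (1 - hu) * exp(mu + sigma^2 / 2)

c(hu_link = hu_link, hu = hu, expected = expected)
```

Under these assumptions, the default prediction corresponds to expected, the dpar = "hu" prediction to hu, and the type = "link" version of the latter to hu_link.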
Plot adjusted predictions as a function of lifeExp
:
plot_predictions(
mod,
condition = "lifeExp") +
labs(y = "mu") +
plot_predictions(
mod,
dpar = "hu",
condition = "lifeExp") +
labs(y = "hu")
Predictions with more than one condition and the re_formula
argument from brms
:
plot_predictions(
mod,
re_formula = NULL,
condition = c("lifeExp", "continent"))
15.4.2 Extract draws with get_draws()
The get_draws() function extracts raw samples from the posterior from objects produced by marginaleffects. This allows us to use richer geoms and summaries, such as those in the ggdist package:
predictions(
mod,
re_formula = NULL,
newdata = datagrid(model = mod,
continent = gapminder$continent,
year = c(1952, 2007),
lifeExp = seq(30, 80, 1))) |>
get_draws() |>
ggplot(aes(lifeExp, draw, fill = continent, color = continent)) +
stat_lineribbon(alpha = .25) +
facet_grid(year ~ continent)
15.4.3 Average Contrasts
What happens to gdpPercap
when lifeExp
increases by one?
avg_comparisons(mod)
What happens to gdpPercap
when lifeExp
increases by one standard deviation?
avg_comparisons(mod, variables = list(lifeExp = "sd"))
What happens to gdpPercap when lifeExp increases from 50 to 60 and year simultaneously increases from its min to its max?
avg_comparisons(
mod,
variables = list(lifeExp = c(50, 60), year = "minmax"),
cross = TRUE)
Plot draws from the posterior distribution of average contrasts (not the same thing as draws from the posterior distribution of contrasts):
avg_comparisons(mod) |>
get_draws() |>
ggplot(aes(draw, term)) +
stat_dotsinterval() +
labs(x = "Posterior distribution of average contrasts", y = "")
15.4.4 Marginal effects (slopes)
Average Marginal Effect of lifeExp
on different scales and for different parameters:
avg_slopes(mod)
avg_slopes(mod, type = "link")
avg_slopes(mod, dpar = "hu")
avg_slopes(mod, dpar = "hu", type = "link")
Plot conditional marginal effects:
plot_slopes(
mod,
variables = "lifeExp",
condition = "lifeExp") +
labs(y = "mu") +
plot_slopes(
mod,
dpar = "hu",
variables = "lifeExp",
condition = "lifeExp") +
labs(y = "hu")
Or we can call slopes() or comparisons() with the get_draws() function to have even more control:
comparisons(
mod,
type = "link",
variables = "lifeExp",
newdata = datagrid(lifeExp = c(40, 70), continent = gapminder$continent)) |>
get_draws() |>
ggplot(aes(draw, continent, fill = continent)) +
stat_dotsinterval() +
facet_grid(lifeExp ~ .) +
labs(x = "Effect of a 1 unit change in Life Expectancy")
15.5 Bayesian estimates and credible intervals
For Bayesian models like those produced by the brms or rstanarm packages, the marginaleffects package functions report the median of the posterior distribution as their main estimates.
The default credible intervals are equal-tailed intervals (quantiles), and the default function to identify the center of the distribution is the median. Users can customize the type of intervals reported by setting global options. Note that both the reported estimate and the intervals change slightly:
library(insight)
library(marginaleffects)
mod <- insight::download_model("brms_1")
options(marginaleffects_posterior_interval = "hdi")
options(marginaleffects_posterior_center = mean)
avg_comparisons(mod)
options(marginaleffects_posterior_interval = "eti")
options(marginaleffects_posterior_center = stats::median)
avg_comparisons(mod)
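The difference between the two interval types, and between the mean and the median, is easiest to see on a skewed distribution. The base R sketch below uses simulated draws from a gamma distribution as a stand-in for a posterior. The hdi helper is our own illustrative function (the shortest window containing 95% of the sorted draws), not the one used internally by marginaleffects.

```r
set.seed(5)

# Hypothetical skewed posterior draws
draws <- rgamma(4000, shape = 2, rate = 1)

# Equal-tailed interval: 2.5% and 97.5% quantiles
eti <- quantile(draws, probs = c(0.025, 0.975), names = FALSE)

# Highest density interval: the shortest window containing 95% of sorted draws
hdi <- function(x, level = 0.95) {
  x <- sort(x)
  n_in <- ceiling(level * length(x))
  starts <- seq_len(length(x) - n_in + 1)
  widths <- x[starts + n_in - 1] - x[starts]
  best <- which.min(widths)
  c(x[best], x[best + n_in - 1])
}

rbind(eti = eti, hdi = hdi(draws))
c(median = median(draws), mean = mean(draws))
```

For a right-skewed distribution like this one, the highest density interval sits closer to zero and is no wider than the equal-tailed interval, and the mean exceeds the median.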
15.6 Random variables: posterior and ggdist
Recent versions of the posterior, brms, and ggdist packages make it easy to draw, summarize and plot random variables. To make those features easy to use, the get_draws() function can return a data frame with a column of type rvar:
avg_comparisons(mod) |>
get_draws(shape = "rvar") |>
ggplot(aes(y = term, xdist = rvar)) +
stat_slabinterval()
15.7 Non-linear hypothesis testing
We begin by estimating a model:
Notice that we can compute average contrasts in two different ways, using the avg_comparisons()
function or the comparison
argument:
avg_comparisons(mod)
comparisons(mod, comparison = "differenceavg")
Now, we use the hypothesis
argument to compare the first to the second rows of the comparisons()
output:
comparisons(
mod,
comparison = "differenceavg",
hypothesis = "b2 - b1 = 0.2")
The hypothesis() function of the brms package can also perform non-linear hypothesis testing, and it generates some convenient statistics and summaries. This function accepts a D-by-P matrix of draws from the posterior distribution, where D is the number of draws and P is the number of parameters. We can obtain such a matrix using get_draws(x, shape = "DxP"), and we can simply add a couple of calls to our chain of operations:
avg_comparisons(mod, comparison = "differenceavg") |>
get_draws(shape = "DxP") |>
brms::hypothesis("b2 - b1 > .2")
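The kind of summary produced for a one-sided hypothesis can be approximated by hand from a D-by-P matrix. The base R sketch below uses a simulated, hypothetical matrix with two columns named b1 and b2. It computes draws of the tested quantity, the share of draws for which the hypothesis holds, and the corresponding ratio of posterior probabilities.

```r
set.seed(6)

# Hypothetical D x P matrix: D = 2000 draws (rows), P = 2 estimates (columns)
draws <- cbind(
  b1 = rnorm(2000, mean = 0.10, sd = 0.05),
  b2 = rnorm(2000, mean = 0.45, sd = 0.05)
)

# Draws of the quantity in the hypothesis "b2 - b1 > .2"
delta <- draws[, "b2"] - draws[, "b1"] - 0.2

# Posterior probability that the hypothesis holds, and the matching evidence ratio
post_prob <- mean(delta > 0)
evid_ratio <- post_prob / (1 - post_prob)
c(estimate = mean(delta), post_prob = post_prob, evid_ratio = evid_ratio)
```

Because the computation works draw by draw, any transformation of the columns, linear or not, can be tested in the same way.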
15.8 Distributional parameters
Some brms
models allow users to model distributional parameters:
We can use marginaleffects
to compute quantities based on posterior draws of those parameters by specifying the dpar
argument, which will be passed internally to brms
’s prediction functions. For example:
avg_predictions(mod)
avg_predictions(mod, dpar = "sigma")
avg_slopes(mod, dpar = "sigma")
16 Manual computation: Counterfactual comparisons
Here is an example which replicates comparisons()
output manually. Hopefully this will help some readers understand what is going on under the hood:
library(marginaleffects)
library(brms)
data("ChickWeight")
mod = brm(data = ChickWeight,
weight ~ Time * Diet + (Time|Chick),
seed = 123,
backend = "cmdstanr")
# re_formula = NA: ignore all group-level effects
comparisons(mod,
variables = "Time",
by = "Diet",
re_formula = NA)
d0 <- ChickWeight
d1 <- transform(d0, Time = Time + 1)
p0 <- posterior_epred(mod, newdata = d0, re_formula = NA)
p1 <- posterior_epred(mod, newdata = d1, re_formula = NA)
p <- p1 - p0
cmp <- apply(p, 1, function(x) tapply(x, ChickWeight$Diet, mean))
apply(cmp, 1, quantile, probs = .025)
# re_formula = NULL: include all group-level effects
comparisons(mod,
variables = "Time",
by = "Diet",
re_formula = NULL)
d0 <- ChickWeight
d1 <- transform(d0, Time = Time + 1)
p0 <- posterior_epred(mod, newdata = d0, re_formula = NULL)
p1 <- posterior_epred(mod, newdata = d1, re_formula = NULL)
p <- p1 - p0
cmp <- apply(p, 1, function(x) tapply(x, ChickWeight$Diet, mean))
apply(cmp, 1, quantile, probs = .025)