37 Performance
library(marginaleffects)
library(microbenchmark)
library(tictoc)
options(width = 300)
dat <- get_dataset("airbnb")
mod <- lm(
    price ~ bathrooms * bedrooms * unit_type * Breakfast * Gym,
    data = dat)
37.1 Why is marginaleffects slow, sometimes?
There are four primary reasons why some marginaleffects calls are slow:
- A slow predict() function.
- Too many comparisons.
- A large grid.
- Expensive standard errors.
Let’s consider each of these points in turn, along with potential solutions.
37.2 Slow predict() function
To compute predictions, comparisons, and slopes, marginaleffects must call predict() repeatedly. This function is supplied by the package that implements the model object, and its speed can vary widely.
Unfortunately, since this function is not supplied by marginaleffects itself, there is not much that can be done to speed it up. Therefore, we move on directly to the other sources of sluggishness.
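That said, it is easy to check whether predict() is the bottleneck by timing it directly. A minimal sketch, using the microbenchmark package loaded above:
# Time the model-supplied predict() method directly; if this is slow,
# marginaleffects calls will be slow too, since they invoke it repeatedly.
microbenchmark(predict(mod, newdata = dat), times = 10)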
37.3 Too many comparisons
When using the slopes() or comparisons() families of functions, the default is to compute the “effect” of each focal predictor, one after the other. When there are many predictors, this can be expensive.
37.3.1 Solution: variables argument
A simple solution is to specify the variables argument to compute estimates only for the predictors of interest. Compare:
tic()
cmp <- avg_comparisons(mod, variables = "bathrooms")
toc()
3.142 sec elapsed
tic()
cmp <- avg_comparisons(mod)
toc()
9.607 sec elapsed
37.4 Large grid
There are two main reasons why a grid might be large.
First, the dataset used to fit the model has many rows, and you have not specified a newdata argument. In the fitted example above, the airbnb dataset is large. Therefore, if we do not specify a newdata argument, marginaleffects will compute one estimate (and standard error) for each row of the dataset. This can be very costly.
p <- predictions(mod)
dim(p)
[1] 52717 64
The second situation in which a grid can be large occurs when the user supplies a very large grid through the newdata argument, perhaps built with the datagrid() helper.
Consider a simple dataset with 4 categorical variables, each with 26 possible levels:
let <- data.frame(
    A = sample(letters, 1000, replace = TRUE),
    B = sample(letters, 1000, replace = TRUE),
    C = sample(letters, 1000, replace = TRUE),
    D = sample(letters, 1000, replace = TRUE)
)
nrow(let)
[1] 1000
This dataset has 1000 rows. However, if we use the datagrid() function to create a “balanced grid” with every combination of levels, we end up with 26^4 = 456,976 rows:
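A minimal sketch of such a call, assuming the grid_type argument of datagrid():
# Build a balanced grid containing every combination of the unique
# values of A, B, C, and D: 26 * 26 * 26 * 26 = 456,976 rows.
bal <- datagrid(newdata = let, grid_type = "balanced")
nrow(bal)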
37.4.1 Solution A: Estimates at representative values
Instead of computing predictions, comparisons, or slopes for every row of a large grid, it is much faster to compute just a few estimates. For example, you may want to compute predictions for an individual with average (or modal) characteristics, or for a unit with specific values of interest.
predictions(mod, newdata = "mean")
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
115 0.458 252 <0.001 Inf 114 116
Type: response
predictions(mod, newdata = datagrid(Gym = 1, bathrooms = 2, bedrooms = 6))
Gym bathrooms bedrooms Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
1 2 6 306 13 23.6 <0.001 405.1 281 332
Type: response
37.4.2 Solution B: Random subset of the grid
An alternative is to sample a subset of rows at random from the large grid. If the subset is large enough, the statistics we compute on that basis will often be very similar to the full-sample values.
Here, we use avg_comparisons() to compute the “average effect” of bathrooms on price, and find that the estimates are reasonably close whether we use the full grid or a subset.
set.seed(48103)
# random subset of the grid
sub <- dat[sample(nrow(dat), 1000), ]
avg_comparisons(mod, variables = "bathrooms", newdata = sub)
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
26.9 0.601 44.8 <0.001 Inf 25.8 28.1
Term: bathrooms
Type: response
Comparison: +1
# full grid
avg_comparisons(mod, variables = "bathrooms")
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
27.2 0.599 45.4 <0.001 Inf 26 28.4
Term: bathrooms
Type: response
Comparison: +1
37.4.3 Solution C: Weighted grid
When all the predictor variables are categorical, we can use a third approach: a weighted grid of unique rows. Consider again our airbnb model. There are five categorical predictors in this model:
pred <- insight::find_predictors(mod, flatten = TRUE)
pred
[1] "bathrooms" "bedrooms" "unit_type" "Breakfast" "Gym"
What we can do is create a grid of unique rows, and assign weights to each row based on how many times it appears in the original dataset.
This grid is much smaller than the original dataset:
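One way to build such a grid, as a sketch in base R (the original construction may differ; the "weight" column name matches the wts argument used below):
# Count how many times each unique combination of the predictors appears,
# storing the counts in a column named "weight".
grid <- aggregate(
    reformulate(pred, response = "weight"),
    data = transform(dat, weight = 1),
    FUN = sum)
nrow(grid)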
Now we can compute a quantity of interest with marginaleffects, using the wts argument to “expand” the grid and obtain the same result as if we had used the full dataset, but much faster.
# full dataset
tic()
p_full <- avg_predictions(mod)
toc()
0.485 sec elapsed
# weighted grid
tic()
p_subset <- avg_predictions(mod, newdata = grid, wts = "weight")
toc()
0.043 sec elapsed
p_full
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
94.7 0.24 395 <0.001 Inf 94.2 95.2
Type: response
p_subset
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
94.7 0.24 395 <0.001 Inf 94.2 95.2
Type: response
37.5 Standard errors are expensive
The final reason why some marginaleffects calls take time is that standard errors are expensive to compute. This is especially true when the fitted model includes many parameters (coefficients), because the default strategy for computing standard errors involves calling predict() at least once per coefficient.
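To see why, here is an illustrative sketch of a delta-method standard error for the average prediction, with the Jacobian approximated by one-sided finite differences. This is not the actual marginaleffects internal code, but it shows the one-predict()-call-per-coefficient pattern:
# Average prediction as a function of the coefficient vector.
avg_pred_at <- function(beta) {
    m <- mod
    m$coefficients[names(beta)] <- beta
    mean(predict(m, newdata = dat))
}
beta_hat <- coef(mod)
beta_hat <- beta_hat[!is.na(beta_hat)] # drop aliased coefficients, if any
baseline <- avg_pred_at(beta_hat)
eps <- 1e-5
jac <- vapply(names(beta_hat), function(nm) {
    b <- beta_hat
    b[nm] <- b[nm] + eps
    (avg_pred_at(b) - baseline) / eps # one fresh predict() call per coefficient
}, numeric(1))
sqrt(drop(jac %*% vcov(mod) %*% jac)) # delta-method standard error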
37.5.1 Solution A: No standard errors
It is always much faster to compute estimates without standard errors. If you do not need them, set vcov = FALSE.
tic()
p <- avg_predictions(mod, vcov = FALSE)
toc()
0.035 sec elapsed
tic()
p <- avg_predictions(mod)
toc()
0.476 sec elapsed
37.5.2 Solution B: Automatic differentiation
For some models, marginaleffects can compute standard errors using automatic differentiation (AD). This is often much faster (and sometimes more accurate) than the default finite difference approach.
To enable AD, we need to install the marginaleffectsAD package for Python, and call it via the reticulate package. This sounds more complicated than it is. In most cases, all the user has to do is call the autodiff() function supplied by marginaleffects.
install.packages("reticulate")
library("marginaleffects")
autodiff(install = TRUE)
After installing the dependencies, we can call autodiff(TRUE) to enable AD for the current R session, and use the microbenchmark package to compare the performance of two similar calls. The AD approach is much faster in this example.
automatic_differentiation <- function() {
    autodiff(TRUE)
    avg_predictions(mod)
}
finite_differences <- function() {
    autodiff(FALSE)
    avg_predictions(mod)
}
microbenchmark(
    automatic_differentiation(),
    finite_differences(),
    times = 5
)
Unit: milliseconds
expr min lq mean median uq max neval cld
automatic_differentiation() 40.09452 47.33257 1181.792 47.6347 62.22422 5711.6738 5 a
finite_differences() 468.62438 475.61513 505.874 484.5733 549.58274 550.9742 5 a
37.5.3 Solution C: Parallel processing
For models and calls that are not supported by automatic differentiation, another option is to compute standard errors in parallel.
Parallel processing does not always speed up marginaleffects much, because of the overhead involved in passing large datasets between cores or forked processes. This strategy is more likely to be useful when:
- The modelling package’s predict() function is slow.
- There are many parameters in the model.
- The grid is not large enough to impose too much overhead.
To use parallel processing, we call the plan() function from the future package and set some global options.
Note that the global option for parallel processing is different for the inferences() function; see the documentation of that function for details.
Here is an example with nycflights13 data and a complex model with many parameters.
library(mgcv) |> suppressPackageStartupMessages()
library(tictoc)
library(future)
library(nycflights13)
library(marginaleffects)
data("flights")
options(future.globals.maxSize = 8000 * 1024^2) # 8 GB memory
packageVersion("marginaleffects")
[1] '0.29.0.7'
cores <- 8
plan(multicore, workers = cores)
flights <- flights |>
    transform(date = as.Date(paste(year, month, day, sep = "/"))) |>
    transform(date.num = as.numeric(date - min(date))) |>
    transform(wday = as.POSIXlt(date)$wday) |>
    transform(time = as.POSIXct(paste(hour, minute, sep = ":"), format = "%H:%M")) |>
    transform(time.dt = difftime(time, as.POSIXct('00:00', format = '%H:%M'), units = 'min')) |>
    transform(time.num = as.numeric(time.dt)) |>
    transform(dep_delay = ifelse(dep_delay < 0, 0, dep_delay)) |>
    transform(dep_delay = ifelse(is.na(dep_delay), 0, dep_delay)) |>
    transform(carrier = factor(carrier)) |>
    transform(dest = factor(dest)) |>
    transform(origin = factor(origin))
model <- bam(
    dep_delay ~
        s(date.num, bs = "cr") +
        s(wday, bs = "cc", k = 3) +
        s(time.num, bs = "cr") +
        s(carrier, bs = "re") +
        origin +
        s(distance, bs = "cr") +
        s(dest, bs = "re"),
    data = flights,
    family = poisson,
    discrete = TRUE,
    nthreads = cores)
# there are many parameters
length(coef(model))
[1] 153
Now, compare three calls to predictions().
# Without standard errors: very fast
tic()
p1 <- predictions(model, vcov = FALSE)
toc()
0.415 sec elapsed
# Parallel standard errors: slower
options("marginaleffects_parallel" = TRUE)
tic()
p2 <- predictions(model)
toc()
10.709 sec elapsed
# Sequential standard errors: even slower
options("marginaleffects_parallel" = FALSE)
tic()
p3 <- predictions(model)
toc()
32.901 sec elapsed