EXPERIMENTAL – Enable Automatic Differentiation with JAX

Description

This function enables or disables automatic differentiation using the JAX package in Python. When a model includes many parameters, automatic differentiation can considerably speed up the computation of standard errors and improve their accuracy.

Usage

autodiff(autodiff = NULL, install = FALSE)

Arguments

autodiff Logical flag. If TRUE, enables automatic differentiation with JAX. If FALSE (default), disables automatic differentiation and reverts to finite difference methods.
install Logical flag. If TRUE, installs the marginaleffectsAD Python package via reticulate::py_install(). Default is FALSE. This is only necessary if you are self-managing a Python installation.

Details

When autodiff = TRUE, this function:

  • Imports the marginaleffectsAD Python package via reticulate::import()

  • Sets the internal jacobian function to use JAX-based automatic differentiation

  • Provides faster and more accurate gradient computation for supported models

  • Falls back to the default finite difference method for unsupported models and calls
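
The difference between the two jacobian strategies can be illustrated in base R. This is a minimal sketch, not the package's implementation: it does not call JAX, the helper names (predict_fun, jacobian_fd, jacobian_exact) are hypothetical, and a hand-derived analytic jacobian stands in for the exact derivatives that automatic differentiation would supply. It shows that a finite difference jacobian needs one extra function evaluation per parameter and only approximates the exact result.

```r
# Logit model, same specification as in the Examples section
mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)
X <- model.matrix(mod)
beta <- coef(mod)

# Predicted probabilities as a function of the coefficient vector
predict_fun <- function(b) as.vector(plogis(X %*% b))

# Forward finite difference jacobian: one extra evaluation per parameter
jacobian_fd <- function(f, b, eps = 1e-7) {
  base <- f(b)
  sapply(seq_along(b), function(j) {
    b_step <- b
    b_step[j] <- b_step[j] + eps
    (f(b_step) - base) / eps
  })
}

# Exact jacobian from the chain rule: row i is dlogis(x_i' b) * x_i
jacobian_exact <- function(b) as.vector(dlogis(X %*% b)) * X

J_fd <- jacobian_fd(predict_fun, beta)
J_exact <- jacobian_exact(beta)

# Largest absolute discrepancy between the two jacobians
max_abs_diff <- max(abs(J_fd - J_exact))

# Delta method standard errors: square root of diag(J V J')
se_exact <- sqrt(rowSums((J_exact %*% vcov(mod)) * J_exact))
```

The discrepancy in max_abs_diff depends on the step size eps and on the scale of the predictors, which is one reason exact derivatives can be preferable when many parameters are involved.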

Currently supports:

  • Model types: lm, glm, ols, lrm

  • Functions: predictions() and comparisons(), along with avg_ and plot_ variants.

  • type: "response" or "link"

  • by: TRUE, FALSE, or character vector.

  • comparison: "difference" and "ratio"

For unsupported models or options, marginaleffects automatically falls back to finite difference methods and issues a warning.

Value

No return value. Called for side effects of enabling/disabling automatic differentiation.

Python Configuration

By default, no manual configuration of Python should be necessary. On most machines, unless you have explicitly configured reticulate, reticulate defaults to an automatically managed ephemeral virtual environment with all Python requirements declared via reticulate::py_require().

If you prefer to use a manually managed Python installation, you can tell reticulate which Python executable or environment to use. reticulate selects a Python installation using its Order of Discovery. As a convenience, autodiff(install = TRUE) installs marginaleffectsAD in a self-managed virtual environment.

To specify an alternate Python version:

library(reticulate)
use_python("/usr/local/bin/python")

To use a virtual environment:

use_virtualenv("myenv")

Run these configuration commands before calling autodiff().
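
Putting the steps in order, a session with a self-managed environment might look like the sketch below. It assumes that a virtual environment named "myenv" (the placeholder name used above) already exists and that marginaleffectsAD is installed in it. virtualenv_exists() and py_config() are standard reticulate helpers, used here only to check the setup.

```r
library(reticulate)

# "myenv" is a placeholder name; substitute your own environment
if (virtualenv_exists("myenv")) {
  # required = TRUE makes reticulate raise an error instead of
  # silently falling back to another Python installation
  use_virtualenv("myenv", required = TRUE)
}

# Report which Python installation reticulate is bound to
py_config()

# Enable automatic differentiation only after Python is configured
library(marginaleffects)
autodiff(TRUE)
```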

Examples

library("marginaleffects")

# Install the Python package (only needed once, and only for a self-managed Python installation)
autodiff(install = TRUE)

# Enable automatic differentiation
autodiff(TRUE)

# Fit a model and compute marginal effects
mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)
avg_comparisons(mod) # Will use JAX for faster computation

# Disable automatic differentiation
autodiff(autodiff = FALSE)