G-Computation
This vignette has 3 goals:
- Give a concise introduction to the idea of the “Parametric g-Formula”
- Highlight the equivalence between one form of g-computation and the “Average Contrasts” computed by marginaleffects
- Show how to obtain estimates, standard errors, and confidence intervals via the parametric g-formula, using a single line of marginaleffects code. This is convenient because analysts typically have to construct counterfactual datasets manually and bootstrap their estimates.
The “Parametric g-Formula” is often used for causal inference in observational data.
The explanations and illustrations that follow draw heavily on Chapter 13 of this excellent book (free copy available online):
Hernán MA, Robins JM (2020). Causal Inference: What If. Boca Raton: Chapman & Hall/CRC.
What is the parametric g-formula?
The parametric g-formula is a method of standardization which can be used to address confounding problems in causal inference with observational data. It relies on the same identification assumptions as Inverse Probability Weighting (IPW), but uses different modeling assumptions. Whereas IPW models the treatment equation, standardization models the mean outcome equation. As Hernán and Robins note:
“Both IP weighting and standardization are estimators of the g-formula, a general method for causal inference first described in 1986. … We say that standardization is a ‘plug-in g-formula estimator’ because it simply replaces the conditional mean outcome in the g-formula by its estimates. When, like in Chapter 13, those estimates come from parametric models, we refer to the method as the parametric g-formula.”
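To make this concrete, the sketch below computes both estimators on hypothetical simulated data with a single binary confounder \(W\). It assumes nonparametric (saturated) models: cell means for the outcome and cell proportions for the treatment. The parametric regression used in the rest of this vignette is not used here. Under that assumption the two estimators coincide exactly, while the naive difference in means remains confounded. The data-generating values are arbitrary.

```python
import random

# Hypothetical simulated data (w, x, y): binary confounder W raises both the
# probability of treatment X and the outcome Y; the true effect of X is 2.
def simulate(n, seed=1):
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        w = 1 if rng.random() < 0.4 else 0
        x = 1 if rng.random() < (0.7 if w else 0.3) else 0
        y = 1.0 + 2.0 * x + 3.0 * w + rng.gauss(0, 1)
        rows.append((w, x, y))
    return rows

def standardization(rows, x):
    """Plug-in g-formula: sum over w of mean(Y | X=x, W=w) * Pr(W=w)."""
    n = len(rows)
    total = 0.0
    for w in (0, 1):
        stratum = [r for r in rows if r[0] == w]
        ys = [r[2] for r in stratum if r[1] == x]
        total += (sum(ys) / len(ys)) * (len(stratum) / n)
    return total

def ip_weighting(rows, x):
    """IP weighting with a saturated (cell-proportion) model for Pr(X=x | W)."""
    n = len(rows)
    prob = {}
    for w in (0, 1):
        stratum = [r for r in rows if r[0] == w]
        prob[w] = sum(1 for r in stratum if r[1] == x) / len(stratum)
    return sum(r[2] / prob[r[0]] for r in rows if r[1] == x) / n

rows = simulate(5000)
treated = [r[2] for r in rows if r[1] == 1]
untreated = [r[2] for r in rows if r[1] == 0]
naive = sum(treated) / len(treated) - sum(untreated) / len(untreated)  # confounded by W
ate_std = standardization(rows, 1) - standardization(rows, 0)
ate_ipw = ip_weighting(rows, 1) - ip_weighting(rows, 0)
print(round(naive, 3), round(ate_std, 3), round(ate_ipw, 3))
```

With parametric models, as in Chapter 13, the two estimators generally differ, because each relies on a different model being correctly specified.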
How does it work?
Imagine a causal model with three arrows: \(W \rightarrow X\), \(W \rightarrow Y\), and \(X \rightarrow Y\).
We want to estimate the effect of a binary treatment \(X\) on outcome \(Y\), but there is a confounding variable \(W\). We can use standardization with the parametric g-formula to handle this. Roughly speaking, the procedure is as follows:
- Step 1: Use the observed data to fit a regression model with \(Y\) as the outcome, \(X\) as the treatment, and \(W\) as a control variable (possibly with polynomials and/or interactions, especially when there are multiple control variables).
- Step 2: Create a new dataset identical to the original data, except that \(X=1\) in every row.
- Step 3: Create a new dataset identical to the original data, except that \(X=0\) in every row.
- Step 4: Use the model from Step 1 to compute adjusted predictions in the two counterfactual datasets from Steps 2 and 3.
- Step 5: The quantity of interest is the difference between the means of the adjusted predictions in the two counterfactual datasets.
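The five steps can be traced in a standard-library Python sketch. It is an illustration only: the data are simulated and hypothetical, the outcome model is a hand-rolled least-squares fit with a treatment-by-confounder interaction, and none of it reflects how marginaleffects is implemented internally.

```python
import random

def solve(a, b):
    """Solve the linear system a @ beta = b by Gauss-Jordan elimination with partial pivoting."""
    k = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(k):
        piv = max(range(col, k), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(k):
            if r != col:
                factor = m[r][col] / m[col][col]
                m[r] = [vr - factor * vc for vr, vc in zip(m[r], m[col])]
    return [m[i][k] / m[i][i] for i in range(k)]

def design(x, w):
    # Intercept, treatment, confounder, and treatment-by-confounder interaction
    return [1.0, x, w, x * w]

def ols(rows):
    """Least squares via the normal equations (X'X) beta = X'y."""
    xs = [design(x, w) for w, x, _ in rows]
    ys = [y for _, _, y in rows]
    p = len(xs[0])
    xtx = [[sum(r[i] * r[j] for r in xs) for j in range(p)] for i in range(p)]
    xty = [sum(r[i] * y for r, y in zip(xs, ys)) for i in range(p)]
    return solve(xtx, xty)

def predict(beta, x, w):
    return sum(b * d for b, d in zip(beta, design(x, w)))

# Hypothetical data (w, x, y): continuous confounder W affects both X and Y;
# the effect of X is 2 + 0.5 * W, so the average effect in the population is 2.
rng = random.Random(42)
rows = []
for _ in range(2000):
    w = rng.gauss(0, 1)
    x = 1 if rng.random() < (0.7 if w > 0 else 0.3) else 0
    y = 1.0 + (2.0 + 0.5 * w) * x + 3.0 * w + rng.gauss(0, 1)
    rows.append((w, x, y))

beta = ols(rows)                                         # Step 1: outcome model
pred1 = [predict(beta, 1, w) for w, _, _ in rows]        # Steps 2 and 4: X=1 in every row
pred0 = [predict(beta, 0, w) for w, _, _ in rows]        # Steps 3 and 4: X=0 in every row
ate = sum(pred1) / len(pred1) - sum(pred0) / len(pred0)  # Step 5: difference of means
print(round(ate, 3))
```

With an interaction term, the estimate is the coefficient on \(X\) plus the interaction coefficient times the sample mean of \(W\). No single coefficient equals the estimate, which is why the prediction-based recipe is convenient.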
This is equivalent to computing an “Average Contrast”, in which the value of \(X\) moves from 0 to 1. Thanks to this equivalence, we can apply the parametric g-formula method using a single line of code in marginaleffects and obtain delta method standard errors automatically.
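In symbols, assume the usual identification conditions: conditional exchangeability, positivity, and consistency. Standardization then identifies the counterfactual mean \(E[Y^{x}]\) by averaging the conditional mean outcome over the distribution of \(W\). The plug-in estimator makes two substitutions. The distribution of \(W\) is replaced by the observed values \(w_1, \dots, w_n\), and the conditional mean is replaced by the fitted model \(\hat{E}[Y \mid X, W]\) from Step 1. Because an average is linear, the difference of the two means equals the mean of unit-level contrasts, which is the “Average Contrast”. The symbol \(\hat{\tau}\) is introduced here only for convenience.

```latex
E[Y^{x}] = \sum_{w} E[Y \mid X = x, W = w] \, \Pr[W = w]

\hat{\tau}
  = \frac{1}{n}\sum_{i=1}^{n} \hat{E}[Y \mid X = 1, W = w_i]
  - \frac{1}{n}\sum_{i=1}^{n} \hat{E}[Y \mid X = 0, W = w_i]
  = \frac{1}{n}\sum_{i=1}^{n} \Big( \hat{E}[Y \mid X = 1, W = w_i] - \hat{E}[Y \mid X = 0, W = w_i] \Big)
```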
Example with real-world data
Let’s illustrate this method by replicating an example from Chapter 13 of Hernán and Robins. The data come from the National Health and Nutrition Examination Survey Data I Epidemiologic Follow-up Study (NHEFS). The outcome is wt82_71, a measure of weight gain. The treatment is qsmk, a binary measure of smoking cessation. There are many confounders.
Step 1 is to fit a regression model of the outcome on the treatment and control variables:
library(boot)
library(marginaleffects)
f <- wt82_71 ~ qsmk + sex + race + age + I(age * age) + factor(education) +
smokeintensity + I(smokeintensity * smokeintensity) + smokeyrs +
I(smokeyrs * smokeyrs) + factor(exercise) + factor(active) + wt71 +
I(wt71 * wt71) + I(qsmk * smokeintensity)
url <- "https://raw.githubusercontent.com/vincentarelbundock/modelarchive/main/data-raw/nhefs.csv"
nhefs <- read.csv(url)
nhefs <- na.omit(nhefs[, all.vars(f)])
fit <- glm(f, data = nhefs)
import polars as pl
import statsmodels.formula.api as smf
from marginaleffects import *

f = "wt82_71 ~ qsmk + sex + race + age + pow(age,2) + C(education) + \
    smokeintensity + pow(smokeintensity,2) + smokeyrs + \
    pow(smokeyrs,2) + C(exercise) + C(active) + wt71 + \
    pow(wt71,2) + qsmk*smokeintensity"

nhefs = pl.read_csv("https://raw.githubusercontent.com/vincentarelbundock/modelarchive/main/data-raw/nhefs.csv")

# drop rows with missing values in any variable used by the model
variables = ["wt82_71", "qsmk", "sex", "race", "age", "education", "smokeintensity", "smokeyrs", "exercise", "active", "wt71"]
nhefs = nhefs.filter(
    ~pl.any_horizontal(pl.col(variables).is_null())
)

fit = smf.glm(f, data=nhefs.to_pandas()).fit()
Steps 2 and 3 require us to replicate the full dataset by setting the qsmk treatment to counterfactual values. We can do this automatically by calling comparisons().
TLDR
These simple commands do everything we need to apply the parametric g-formula:
avg_comparisons(fit, variables = list(qsmk = 0:1))
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
3.52 0.44 7.99 <0.001 49.4 2.65 4.38
Term: qsmk
Type: response
Comparison: mean(1) - mean(0)
Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted
cmp = avg_comparisons(fit, variables = {'qsmk' : [0,1]})
print(cmp)
shape: (1, 9)
┌──────┬──────────────────────────┬──────────┬───────────┬───┬──────────┬──────┬──────┬───────┐
│ Term ┆ Contrast ┆ Estimate ┆ Std.Error ┆ … ┆ P(>|z|) ┆ S ┆ 2.5% ┆ 97.5% │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │
╞══════╪══════════════════════════╪══════════╪═══════════╪═══╪══════════╪══════╪══════╪═══════╡
│ qsmk ┆ mean(True) - mean(False) ┆ 3.52 ┆ 0.44 ┆ … ┆ 1.33e-15 ┆ 49.4 ┆ 2.65 ┆ 4.38 │
└──────┴──────────────────────────┴──────────┴───────────┴───┴──────────┴──────┴──────┴───────┘
Columns: term, contrast, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high
The rest of the vignette walks through the process in a bit more detail and compares the results to the replication code from Hernán and Robins.
Adjusted Predictions
We can compute average predictions in the original data, and average predictions in the two counterfactual datasets like this:
## average predicted outcome in the original data
p <- predictions(fit)
mean(p$estimate)
[1] 2.6383
## average predicted outcome in the two counterfactual datasets
p <- predictions(fit, newdata = datagrid(qsmk = 0:1, grid_type = "counterfactual"))
aggregate(estimate ~ qsmk, data = p, FUN = mean)
qsmk estimate
1 0 1.756213
2 1 5.273587
## average predicted outcome in the original data
p = predictions(fit)
print(p['estimate'].mean())
2.6382997865627953
## average predicted outcome in the two counterfactual datasets
p = predictions(fit, newdata = datagrid(qsmk = [0,1], grid_type = "counterfactual"))
agg = p.group_by('qsmk').agg(pl.col('estimate').mean())
print(agg)
shape: (2, 2)
┌──────┬──────────┐
│ qsmk ┆ estimate │
│ --- ┆ --- │
│ i64 ┆ f64 │
╞══════╪══════════╡
│ 1 ┆ 5.273587 │
│ 0 ┆ 1.756213 │
└──────┴──────────┘
In the R code that accompanies their book, Hernán and Robins compute the same quantities manually, as follows:
## create a dataset with 3 copies of each subject
nhefs$interv <- -1 # 1st copy: equal to original one
interv0 <- nhefs # 2nd copy: treatment set to 0, outcome to missing
interv0$interv <- 0
interv0$qsmk <- 0
interv0$wt82_71 <- NA
interv1 <- nhefs # 3rd copy: treatment set to 1, outcome to missing
interv1$interv <- 1
interv1$qsmk <- 1
interv1$wt82_71 <- NA
onesample <- rbind(nhefs, interv0, interv1) # combining datasets
## linear model to estimate mean outcome conditional on treatment and confounders
## parameters are estimated using original observations only (nhefs)
## parameter estimates are used to predict mean outcome for observations with
## treatment set to 0 (interv=0) and to 1 (interv=1)
std <- glm(f, data = onesample)
onesample$predicted_meanY <- predict(std, onesample)
## estimate mean outcome in each of the groups interv=0, and interv=1
## this mean outcome is a weighted average of the mean outcomes in each combination
## of values of treatment and confounders, that is, the standardized outcome
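mean(onesample[which(onesample$interv == -1), ]$predicted_meanY)
[1] 2.6383
mean(onesample[which(onesample$interv == 0), ]$predicted_meanY)
[1] 1.756213
mean(onesample[which(onesample$interv == 1), ]$predicted_meanY)
[1] 5.273587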
The R code from Hernán and Robins’ book that computes these quantities manually can be translated into Python as follows:
# create a dataset with 3 copies of each subject
nhefs = nhefs.with_columns(
    pl.lit(-1).alias('interv')  # 1st copy: equal to original one
)
interv0 = nhefs  # 2nd copy: treatment set to 0, outcome to missing
interv0 = interv0.with_columns(
    pl.lit(0).alias('interv'),
    pl.lit(0).alias('qsmk').cast(pl.Int64),
    pl.lit(None, dtype=pl.Float64).alias('wt82_71')  # typed null so the copies stack with nhefs
)
interv1 = nhefs  # 3rd copy: treatment set to 1, outcome to missing
interv1 = interv1.with_columns(
    pl.lit(1).alias('interv'),
    pl.lit(1).alias('qsmk').cast(pl.Int64),
    pl.lit(None, dtype=pl.Float64).alias('wt82_71')
)
## linear model to estimate mean outcome conditional on treatment and confounders
## parameters are estimated using original observations only (nhefs)
std = smf.glm(f, data = nhefs.to_pandas()).fit()
onesample = pl.concat([nhefs, interv0, interv1], how='vertical')  # combining datasets
## parameter estimates are used to predict mean outcome for observations with
## treatment set to 0 (interv=0) and to 1 (interv=1)
onesample = onesample.with_columns(
    pl.Series(name = 'predicted_meanY', values = std.predict(onesample.to_pandas()))
)
## estimate mean outcome in each of the groups interv=0, and interv=1
## this mean outcome is a weighted average of the mean outcomes in each combination
## of values of treatment and confounders, that is, the standardized outcome
onesample.filter(pl.col('interv') == -1)['predicted_meanY'].mean()
2.638299786562795
onesample.filter(pl.col('interv') == 0)['predicted_meanY'].mean()
1.7562131154657705
onesample.filter(pl.col('interv') == 1)['predicted_meanY'].mean()
5.27358731635129
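It may be useful to note that the datagrid() function provided by marginaleffects can create counterfactual datasets automatically. With grid_type = "counterfactual", it stacks one full copy of the data for each requested value of qsmk. This is equivalent to the interv0 and interv1 copies in the onesample dataset.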
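For intuition, the hypothetical helper below builds the same kind of stacked counterfactual dataset in plain Python. It is not the marginaleffects implementation. The function name `counterfactual_grid`, the `rowid` column, and the three toy rows are all assumptions for illustration; the rows are made up and are not taken from NHEFS.

```python
def counterfactual_grid(rows, variable, values):
    """Hypothetical helper: stack one full copy of `rows` per value of `variable`.

    Each copy keeps every other column unchanged and records the original row
    index in 'rowid', so unit-level contrasts can be matched across copies.
    """
    grid = []
    for value in values:
        for i, row in enumerate(rows):
            new_row = dict(row)  # copy the row so the original data stay untouched
            new_row[variable] = value
            new_row["rowid"] = i
            grid.append(new_row)
    return grid

# Made-up toy rows with a treatment column and two covariates
data = [
    {"qsmk": 0, "age": 42, "wt71": 79.0},
    {"qsmk": 1, "age": 36, "wt71": 58.6},
    {"qsmk": 0, "age": 56, "wt71": 56.8},
]
grid = counterfactual_grid(data, "qsmk", [0, 1])
print(len(grid))                  # 6: two copies of three rows
print([r["qsmk"] for r in grid])  # [0, 0, 0, 1, 1, 1]
```

The grid keeps a row index so that each unit's prediction under qsmk = 1 can be paired with its prediction under qsmk = 0. That pairing is what a unit-level contrast needs.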
Contrast
Now we want to compute the treatment effect with the parametric g-formula, which is the difference in average predicted outcomes in the two counterfactual datasets. This is equivalent to taking the average contrast with avg_comparisons(), the averaging counterpart of comparisons(). There are two important things to note in the command that follows:
- The variables argument indicates that we want to estimate a “contrast” between adjusted predictions when qsmk is equal to 1 or 0.
- comparisons() automatically produces estimates of uncertainty.
avg_comparisons(std, variables = list(qsmk = 0:1))
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
3.52 0.44 7.99 <0.001 49.4 2.65 4.38
Term: qsmk
Type: response
Comparison: mean(1) - mean(0)
Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted
cmp = avg_comparisons(std, variables = {"qsmk" : [0,1]})
print(cmp)
shape: (1, 9)
┌──────┬──────────────────────────┬──────────┬───────────┬───┬──────────┬──────┬──────┬───────┐
│ Term ┆ Contrast ┆ Estimate ┆ Std.Error ┆ … ┆ P(>|z|) ┆ S ┆ 2.5% ┆ 97.5% │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │
╞══════╪══════════════════════════╪══════════╪═══════════╪═══╪══════════╪══════╪══════╪═══════╡
│ qsmk ┆ mean(True) - mean(False) ┆ 3.52 ┆ 0.44 ┆ … ┆ 1.33e-15 ┆ 49.4 ┆ 2.65 ┆ 4.38 │
└──────┴──────────────────────────┴──────────┴───────────┴───┴──────────┴──────┴──────┴───────┘
Columns: term, contrast, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high
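Under the hood, comparisons() did exactly what we described in the g-formula steps above. It computed each unit's predicted outcome with qsmk set to 1 and to 0, then averaged the unit-level differences, which equals the difference of the two counterfactual means.
We can obtain the same result manually from the replication code of Hernán and Robins, by subtracting the two standardized means computed earlier: 5.273587 - 1.756213 = 3.517374.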
Although manual computation is simple, it does not provide uncertainty estimates. In contrast, comparisons() has already computed the standard error and confidence interval using the delta method.
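The delta method approximates the variance of a smooth function \(g(\hat{\beta})\) of the model coefficients as \(\nabla g(\hat{\beta})^\top \hat{V} \nabla g(\hat{\beta})\), where \(\hat{V}\) is the estimated covariance matrix of \(\hat{\beta}\). The sketch below applies that formula to an average contrast from a hypothetical linear model with a treatment-by-confounder interaction, similar in form to the qsmk * smokeintensity term above. It uses a forward finite-difference gradient. The coefficients, covariance matrix, confounder values, and step size are made-up assumptions. This is a generic illustration, not the marginaleffects source code.

```python
import math

def average_contrast(beta, w_values):
    """Mean over units of E[Y | X=1, W=w] - E[Y | X=0, W=w]
    for the linear model Y = b0 + b1*X + b2*W + b3*X*W."""
    b0, b1, b2, b3 = beta
    diffs = [(b0 + b1 + b2 * w + b3 * w) - (b0 + b2 * w) for w in w_values]
    return sum(diffs) / len(diffs)

def delta_method_se(func, beta, vcov, eps=1e-6):
    """Standard error of func(beta) using a forward-difference gradient."""
    base = func(beta)
    grad = []
    for j in range(len(beta)):
        shifted = list(beta)
        shifted[j] += eps
        grad.append((func(shifted) - base) / eps)
    # quadratic form grad' V grad
    var = sum(grad[i] * vcov[i][j] * grad[j]
              for i in range(len(beta)) for j in range(len(beta)))
    return math.sqrt(var)

# Made-up coefficients, covariance matrix, and confounder values
beta = [1.0, 2.0, 3.0, 0.5]
vcov = [
    [0.040, -0.010, 0.000, 0.000],
    [-0.010, 0.050, 0.000, -0.005],
    [0.000, 0.000, 0.020, -0.004],
    [0.000, -0.005, -0.004, 0.010],
]
w_values = [0.0, 1.0, 2.0, 3.0]

estimate = average_contrast(beta, w_values)  # b1 + b3 * mean(W) = 2.0 + 0.5 * 1.5 = 2.75
se = delta_method_se(lambda b: average_contrast(b, w_values), beta, vcov)
print(round(estimate, 4), round(se, 4))
```

This average contrast is linear in the coefficients, so its exact gradient is \((0, 1, 0, \bar{w})\) with \(\bar{w} = 1.5\). The numerical result can be checked by hand: \(0.05 + 2(1.5)(-0.005) + 1.5^2(0.01) = 0.0575\), which gives a standard error of about 0.2398.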
Without a tool that implements the delta method, many analysts rely on bootstrapping instead. For example, the replication code from Hernán and Robins does this:
## function to calculate difference in means
standardization <- function(data, indices) {
# create a dataset with 3 copies of each subject
d <- data[indices, ] # 1st copy: equal to original one
d$interv <- -1
d0 <- d # 2nd copy: treatment set to 0, outcome to missing
d0$interv <- 0
d0$qsmk <- 0
d0$wt82_71 <- NA
d1 <- d # 3rd copy: treatment set to 1, outcome to missing
d1$interv <- 1
d1$qsmk <- 1
d1$wt82_71 <- NA
d.onesample <- rbind(d, d0, d1) # combining datasets
# linear model to estimate mean outcome conditional on treatment and confounders
# parameters are estimated using original observations only (interv= -1)
# parameter estimates are used to predict mean outcome for observations with set
# treatment (interv=0 and interv=1)
fit <- glm(f, data = d.onesample)
d.onesample$predicted_meanY <- predict(fit, d.onesample)
# estimate mean outcome in each of the groups interv=-1, interv=0, and interv=1
return(mean(d.onesample$predicted_meanY[d.onesample$interv == 1]) -
mean(d.onesample$predicted_meanY[d.onesample$interv == 0]))
}
## bootstrap
results <- boot(data = nhefs, statistic = standardization, R = 1000)
## generating confidence intervals
se <- sd(results$t[, 1])
meant0 <- results$t0
ll <- meant0 - qnorm(0.975) * se
ul <- meant0 + qnorm(0.975) * se
bootstrap <- data.frame(
" " = "Treatment - No Treatment",
estimate = meant0,
std.error = se,
conf.low = ll,
conf.high = ul,
check.names = FALSE)
bootstrap
estimate std.error conf.low conf.high
1 Treatment - No Treatment 3.517374 0.4766104 2.583235 4.451513
The point estimate matches the one we obtained with comparisons(). The standard error and confidence interval differ slightly (a bootstrap standard error of 0.48 versus 0.44 for the delta method), because bootstrapping and the delta method approximate uncertainty in different ways.
avg_comparisons(fit, variables = list(qsmk = 0:1))
Estimate Std. Error z Pr(>|z|) S 2.5 % 97.5 %
3.52 0.44 7.99 <0.001 49.4 2.65 4.38
Term: qsmk
Type: response
Comparison: mean(1) - mean(0)
Columns: term, contrast, estimate, std.error, statistic, p.value, s.value, conf.low, conf.high, predicted_lo, predicted_hi, predicted
cmp = avg_comparisons(std, variables = {"qsmk" : [0,1]})
print(cmp)
shape: (1, 9)
┌──────┬──────────────────────────┬──────────┬───────────┬───┬──────────┬──────┬──────┬───────┐
│ Term ┆ Contrast ┆ Estimate ┆ Std.Error ┆ … ┆ P(>|z|) ┆ S ┆ 2.5% ┆ 97.5% │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str ┆ ┆ str ┆ str ┆ str ┆ str │
╞══════╪══════════════════════════╪══════════╪═══════════╪═══╪══════════╪══════╪══════╪═══════╡
│ qsmk ┆ mean(True) - mean(False) ┆ 3.52 ┆ 0.44 ┆ … ┆ 1.33e-15 ┆ 49.4 ┆ 2.65 ┆ 4.38 │
└──────┴──────────────────────────┴──────────┴───────────┴───┴──────────┴──────┴──────┴───────┘
Columns: term, contrast, estimate, std_error, statistic, p_value, s_value, conf_low, conf_high
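A Python analogue of the Hernán and Robins bootstrap can be sketched with the standard library alone. To stay self-contained, it does not use the NHEFS regression. Instead it assumes hypothetical simulated data and a saturated cell-means outcome model for a single binary confounder. The sample size, the number of replicates (500), and the seed are arbitrary choices. Like the R code above, it forms a normal-approximation interval: the estimate plus or minus the 97.5% normal quantile times the bootstrap standard error.

```python
import random
import statistics

def simulate(n, rng):
    # Hypothetical data (w, x, y): binary confounder W, binary treatment X, outcome Y
    rows = []
    for _ in range(n):
        w = 1 if rng.random() < 0.4 else 0
        x = 1 if rng.random() < (0.7 if w else 0.3) else 0
        rows.append((w, x, 1.0 + 2.0 * x + 3.0 * w + rng.gauss(0, 1)))
    return rows

def standardized_difference(rows):
    """g-formula with a saturated outcome model (cell means within levels of W)."""
    n = len(rows)
    means = {0: 0.0, 1: 0.0}
    for w in (0, 1):
        stratum = [r for r in rows if r[0] == w]
        for x in (0, 1):
            ys = [r[2] for r in stratum if r[1] == x]
            means[x] += (sum(ys) / len(ys)) * len(stratum) / n
    return means[1] - means[0]

rng = random.Random(2024)
data = simulate(1000, rng)
estimate = standardized_difference(data)

# Nonparametric bootstrap: resample rows with replacement and recompute the estimate
boot = [standardized_difference(rng.choices(data, k=len(data))) for _ in range(500)]
se = statistics.stdev(boot)
z = statistics.NormalDist().inv_cdf(0.975)
conf_low, conf_high = estimate - z * se, estimate + z * se
print(round(estimate, 3), round(se, 3), round(conf_low, 3), round(conf_high, 3))
```

Each replicate refits the whole procedure, so the cost grows with the number of replicates. That is the main practical advantage of the delta method computed by comparisons().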