This notebook presents advanced uses of datagrid(). Readers should be familiar with the content of Section 3.2 before diving in.
22.1 The by argument
Sometimes we want to hold variables at constant values within groups. The by argument in datagrid() allows us to create grids where certain variables are fixed at specific values within subgroups.
For example, we might want to compare predictions at the mean of a continuous variable. In this example, the Sepal.Width column is held at the global mean in every row of the grid:
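As a rough illustration of what such a grid contains, here is a minimal base R sketch. It assumes the data are R's built-in iris dataset (the text mentions Species and Sepal.Width, but does not name the dataset), and grid_global is a hypothetical name. It builds the grid by hand and does not call datagrid().

```r
# Assumption: the data are the built-in iris dataset.
# One row per species; Sepal.Width is fixed at the mean computed
# over all rows of the data, so every row shares the same value.
grid_global <- data.frame(
  Species = unique(iris$Species),
  Sepal.Width = mean(iris$Sepal.Width)
)
grid_global
```

Every row of grid_global has the same Sepal.Width, which is what "held at the global mean" means here.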
Now, we hold Sepal.Width at its mean value within each Species group. This allows us to compare the effect of species while controlling for the typical sepal width in each group:
predictions(mod, newdata = datagrid(Species = unique, Sepal.Width = mean, by = "Species"))
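To see which values the grouped version holds Sepal.Width at, the within-group means can be computed directly. This is a minimal base R sketch, again assuming the built-in iris dataset; grid_by is a hypothetical name, and the code only mimics the grid rather than reproducing what datagrid() does internally.

```r
# Assumption: the data are the built-in iris dataset.
# One row per species, with Sepal.Width averaged within that species.
grid_by <- aggregate(Sepal.Width ~ Species, data = iris, FUN = mean)
grid_by
```

Unlike a grid built from the global mean, each species now gets its own Sepal.Width value.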
22.2 Sampling from large grids with grid_type = "dataframe"
When working with many categorical variables, a balanced grid can become prohibitively large. For instance, with 5 variables each taking 26 possible values, a full factorial grid would have 26^5 = 11,881,376 rows.
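The arithmetic can be checked in a couple of lines. This small sketch uses the 26 lowercase letters as the hypothetical set of levels; n_levels and n_rows are illustrative names.

```r
# Number of rows in a full factorial grid with 1 to 5 variables,
# each taking 26 possible values
n_levels <- length(letters)
n_rows <- n_levels^(1:5)

# Print with thousands separators for readability
format(n_rows, big.mark = ",")
```

Each additional variable multiplies the grid size by 26, which is why the full grid quickly becomes impractical.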
The grid_type = "dataframe" argument, combined with a custom FUN, allows us to sample uniformly from the full factorial space. The resulting grid approximates the full factorial grid, but has a much smaller footprint.
# Simulate data with 5 categorical variables
set.seed(48103)
n <- 10000
dat <- data.frame(
  x1 = sample(letters, n, replace = TRUE),
  x2 = sample(letters, n, replace = TRUE),
  x3 = sample(letters, n, replace = TRUE),
  x4 = sample(letters, n, replace = TRUE),
  x5 = sample(letters, n, replace = TRUE),
  y = rnorm(n))

# Fit model
mod <- lm(y ~ x1 + x2 + x3 + x4 + x5, data = dat)

# Sample 5 values uniformly from each variable's unique entries
sample_unique <- \(x, n) sample(unique(x), n, replace = TRUE)
grid_sampled <- datagrid(
  model = mod,
  grid_type = "dataframe",
  FUN = \(x) sample_unique(x, 5))
grid_sampled
  rowid x1 x2 x3 x4 x5          y
1     1  u  z  b  j  b -0.6469524
2     2  h  l  g  m  p  0.1698563
3     3  p  f  c  h  t -0.5320531
4     4  b  h  d  f  z -0.8316959
5     5  g  c  h  n  u  0.4131603
Now we can make predictions on the much smaller sampled grid:
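A minimal base R sketch of this step follows. It uses predict() in place of predictions() so that it runs without extra packages, and it builds a small sampled grid by hand instead of calling datagrid(). The simulation mirrors the one above; vars and grid_base are hypothetical names, and the sampled values will not match the printed grid.

```r
# Simulate data and fit the model, mirroring the setup above
set.seed(48103)
n <- 10000
vars <- paste0("x", 1:5)
dat <- as.data.frame(
  lapply(setNames(vars, vars), \(v) sample(letters, n, replace = TRUE))
)
dat$y <- rnorm(n)
mod <- lm(y ~ x1 + x2 + x3 + x4 + x5, data = dat)

# Hand-built sampled grid: 5 values drawn from each predictor's unique entries
grid_base <- as.data.frame(
  lapply(dat[vars], \(x) sample(unique(x), 5, replace = TRUE))
)

# Predict on the 5-row grid instead of the full 26^5-row factorial grid
grid_base$estimate <- predict(mod, newdata = grid_base)
grid_base
```

With the grid returned by datagrid(), the analogous call would pass it as the newdata argument of predictions(), as in the earlier examples.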
Note that the grid_type = "dataframe" argument tells datagrid() to construct the grid by applying FUN to each variable individually, and then by binding the resulting columns side by side with cbind(). This is different from other grid types like balanced or counterfactual, which take the Cartesian product of the supplied values with expand.grid().
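The difference between the two strategies is easy to see on a toy example. The sketch below uses only base R and hypothetical values (vals, side_by_side, crossed); it illustrates column binding versus a Cartesian product and is not a description of datagrid()'s internals.

```r
# Two hypothetical variables with three values each
vals <- list(a = c("x", "y", "z"), b = c(1, 2, 3))

# Column binding: the i-th value of a is paired with the i-th value of b
side_by_side <- cbind.data.frame(a = vals$a, b = vals$b)

# Cartesian product: every value of a is paired with every value of b
crossed <- expand.grid(vals)

nrow(side_by_side)
nrow(crossed)
```

Column binding keeps the number of rows equal to the number of values per variable, while the Cartesian product multiplies the counts together.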