Fit an sklearn model whose output is compatible with pymarginaleffects.
This function streamlines the process of fitting sklearn models by:
- Parsing the formula
- Handling missing values
- Creating model matrices
- Fitting the model with specified options
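The four steps above can be pictured with a minimal, dependency-free sketch. It is not the library's implementation: the helper names (parse_formula, listwise_delete, model_matrices, fit_sketch) and the toy MeanRegressor are assumptions made for illustration, and the parser only handles additive formulas of the form y ~ a + b.

```python
import math


def parse_formula(formula):
    """Split 'y ~ a + b' into the response name and a list of predictor names."""
    lhs, rhs = formula.split("~")
    return lhs.strip(), [term.strip() for term in rhs.split("+")]


def is_missing(value):
    """Treat None and float NaN as missing."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def listwise_delete(rows, columns):
    """Keep only rows where every column used by the model is observed."""
    return [row for row in rows if not any(is_missing(row[c]) for c in columns)]


def model_matrices(rows, response, predictors):
    """Build the response vector y and the predictor matrix X."""
    y = [row[response] for row in rows]
    X = [[row[p] for p in predictors] for row in rows]
    return y, X


class MeanRegressor:
    """Toy stand-in for an sklearn estimator: always predicts the training mean."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def fit_sketch(formula, rows, engine):
    """Hypothetical helper: parse, drop incomplete rows, build matrices, fit."""
    response, predictors = parse_formula(formula)
    complete = listwise_delete(rows, [response] + predictors)
    y, X = model_matrices(complete, response, predictors)
    return engine.fit(X, y), complete


rows = [
    {"outcome": 1.0, "distance": 2.0, "incentive": 0.0},
    {"outcome": 3.0, "distance": None, "incentive": 1.0},  # dropped: missing distance
    {"outcome": 5.0, "distance": 4.0, "incentive": 1.0},
]
model, used = fit_sketch("outcome ~ distance + incentive", rows, MeanRegressor())
print(len(used), model.predict([[0.0, 0.0]]))  # 2 [3.0]
```

The second row is removed before fitting because one of the variables named in the formula is missing, so the toy model is fitted on the two complete rows only.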
Parameters
formula
: (str) Model formula. The XGBoost example below instead passes a function that splits the data into y and X.
- Example: "outcome ~ distance + incentive"
data
: (pandas.DataFrame) Dataframe with the response variable and predictors.
engine
: (callable) sklearn model class (e.g., LinearRegression, LogisticRegression)
kwargs_engine
: (dict, default={}) Additional arguments passed to the model initialization.
- Example:
{'weights': weights_array}
Returns
(ModelSklearn) A fitted model wrapped in the ModelSklearn class for compatibility with marginaleffects.
Examples
from marginaleffects import *
from statsmodels.formula.api import ols
import polars.selectors as cs
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer
from sklearn.linear_model import LinearRegression
from sklearn.compose import make_column_transformer
from xgboost import XGBRegressor
# Linear regression: Scikit-learn
military = get_dataset("military")
mod_sk = fit_sklearn(
    "rank ~ officer + hisp + branch",
    data=military,
    engine=LinearRegression(),
)
avg_predictions(mod_sk, by="branch")
# Linear regression: Statsmodels
mod_sm = ols("rank ~ officer + hisp + branch", data=military.to_pandas()).fit()
avg_predictions(mod_sm, by="branch")
# XGBoost: Scikit-learn
airbnb = get_dataset("airbnb")
train, test = train_test_split(airbnb)
catvar = airbnb.select(~cs.numeric()).columns
def selector(data):
    y = data.select(cs.by_name("price", require_all=False))
    X = data.select(~cs.by_name("price", require_all=False))
    return y, X
preprocessor = make_column_transformer(
    (OneHotEncoder(), catvar),
    remainder=FunctionTransformer(lambda x: x.to_numpy()),
)
pipeline = make_pipeline(preprocessor, XGBRegressor())
mod = fit_sklearn(selector, data=train, engine=pipeline)
avg_predictions(mod, newdata=test, by="unit_type")
avg_comparisons(mod, variables={"bedrooms": 2}, newdata=test)
shape: (1, 3)
str | str | f64
"bedrooms" | "+2" | 13.856534
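As a rough illustration of what the last call reports, the sketch below computes an average forward contrast by hand: each row's prediction is compared with the prediction after adding 2 to bedrooms, and the differences are averaged. This assumes that a numeric value in variables requests a forward contrast of that size, which is how the "+2" label reads. The function average_forward_contrast and the toy_price model are hypothetical and are not part of marginaleffects.

```python
def average_forward_contrast(predict, rows, variable, step):
    """Average of predict(row with variable + step) minus predict(row)."""
    diffs = []
    for row in rows:
        shifted = dict(row)  # copy so the original row is left unchanged
        shifted[variable] = row[variable] + step
        diffs.append(predict(shifted) - predict(row))
    return sum(diffs) / len(diffs)


def toy_price(row):
    # Hypothetical nonlinear price model, so the contrast differs across rows.
    return 40.0 + 10.0 * row["bedrooms"] ** 2


listings = [{"bedrooms": 1}, {"bedrooms": 2}, {"bedrooms": 3}]
effect = average_forward_contrast(toy_price, listings, "bedrooms", 2)
print(effect)  # 120.0
```

Because the toy model is nonlinear, the per-row differences are 80, 120, and 160, and the reported quantity is their mean. With a fitted model, predict would be the model's prediction function and the rows would come from newdata.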
Notes
The fitted model includes additional attributes:
data
: The processed data after listwise deletion
formula
: The original formula string
formula_engine
: Set to "sklearn"
model
: The fitted sklearn model object