Does anyone know how to use predictions() from the marginaleffects package with tidymodels?
In this toy example, I want to get the predicted values for the variable state
while holding all other variables at their base levels or mean values.
library(liver)
library(tidymodels)
library(marginaleffects)
# Copy the churn data into a plain data frame
df_churn <- data.frame(churn)

# Stratified 75/25 split on the outcome column
churn_split <- df_churn %>%
  initial_split(prop = 0.75, strata = churn)

# Training portion of the split
churn_train <- training(churn_split)

# Test portion of the split
churn_test <- testing(churn_split)

# Penalised logistic regression on the glmnet engine.
# penalty = lambda (left to be tuned), mixture = alpha.
lr_mod <- logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

# Pre-processing recipe: correlation filter on numeric columns, then
# normalisation, then dummy-encoding of every nominal non-outcome column
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric(), threshold = 0.9) %>%
  step_normalize(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Workflow = model specification + recipe
churn_wkfl <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(churn_recipe)

# 10-fold cross-validation on the training data, stratified by the outcome
set.seed(1)
churn_folds <- vfold_cv(churn_train, v = 10, strata = churn)

# Tune over a grid of 25 candidate hyperparameter values, scored by ROC AUC
set.seed(1)
glmnet_tuning <- tune_grid(
  churn_wkfl,
  resamples = churn_folds,
  grid = 25,
  metrics = metric_set(roc_auc)
)

# Pick the candidate with the best ROC AUC
best_glmnet_model <- select_best(glmnet_tuning, metric = "roc_auc")

# Finalize the workflow, fit it, and try to get adjusted predictions.
# This does not work: the pipeline hands the output of tidy(), not the
# fitted model, to predictions().
final_churn_wkfl <- churn_wkfl %>%
  finalize_workflow(best_glmnet_model) %>%
  fit(churn_train) %>%
  tidy() %>%
  predictions(variables = "state")
CodePudding user response:
Unfortunately, glmnet is not one of the models supported by marginaleffects.
If you switch to one of the supported models (such as a regular glm()),
this will work once you pull out the underlying fit with extract_fit_engine().
library(tidymodels)
library(marginaleffects)
# Load the example churn data set
data("mlc_churn")

# Seed the RNG before the stratified 75/25 split
set.seed(123)
churn_split <- mlc_churn %>%
  initial_split(prop = 0.75, strata = churn)
churn_train <- churn_split %>% training()
churn_test <- churn_split %>% testing()

# Pre-processing recipe: correlation filter on numeric columns, then
# normalisation, then dummy-encoding of every nominal non-outcome column
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric(), threshold = 0.9) %>%
  step_normalize(all_numeric()) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Workflow = recipe + default logistic regression specification
churn_wkfl <- workflow(
  preprocessor = churn_recipe,
  spec = logistic_reg()
)

# Fit the workflow, extract the underlying engine fit, and request
# adjusted predictions for one numeric predictor
churn_wkfl %>%
  fit(data = churn_train) %>%
  extract_fit_engine() %>%
  predictions(variables = "total_intl_calls") %>%
  as_tibble()
#> # A tibble: 5 × 71
#> rowid type predicted std.error conf.low conf.high account_length
#> <int> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 1 response 0.895 0.0119 0.870 0.916 1.76e-17
#> 2 2 response 0.917 0.00620 0.904 0.928 1.76e-17
#> 3 3 response 0.923 0.00543 0.912 0.933 1.76e-17
#> 4 4 response 0.934 0.00549 0.923 0.944 1.76e-17
#> 5 5 response 0.977 0.00840 0.953 0.989 1.76e-17
#> # … with 64 more variables: number_vmail_messages <dbl>,
#> # total_day_minutes <dbl>, total_day_calls <dbl>, total_eve_minutes <dbl>,
#> # total_eve_calls <dbl>, total_night_calls <dbl>, total_night_charge <dbl>,
#> # total_intl_minutes <dbl>, number_customer_service_calls <dbl>,
#> # state_AL <dbl>, state_AR <dbl>, state_AZ <dbl>, state_CA <dbl>,
#> # state_CO <dbl>, state_CT <dbl>, state_DC <dbl>, state_DE <dbl>,
#> # state_FL <dbl>, state_GA <dbl>, state_HI <dbl>, state_IA <dbl>, …
Created on 2022-03-25 by the reprex package (v2.0.1)
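Notice that I did not use variables = c("state")
and instead substituted one of the continuous, numeric predictors.
The recipe's step_dummy() replaces state with one indicator column per level
(state_AL, state_AR, ...), so the fitted engine model no longer contains a single variable named state.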