Home > Back-end >  Adjusted Predictions Tidymodels
Adjusted Predictions Tidymodels

Time:03-26

Does anyone know how to use predictions() from the marginaleffects package with tidymodels? In this toy example, I want to get the predicted churn values across the levels of the variable state while holding all other variables at their base levels or mean values.

library(liver)
library(tidymodels)
library(marginaleffects)

# Copy the churn data (loaded with the liver package) into a plain data frame
df_churn <- data.frame(churn)

# Create data split object: 75% of rows for training, stratified on churn.
# initial_split() samples rows at random, so fix the seed first to make the
# split, and everything fitted on it, reproducible.
set.seed(1)
churn_split <- initial_split(df_churn, prop = 0.75,
                             strata = churn)

# Create the training data
churn_train <- churn_split %>%
  training()

# Create the test data
churn_test <- churn_split %>%
  testing()

# Logistic regression fitted with glmnet. penalty (glmnet's lambda) is left
# as tune() so cross-validation can choose it; mixture = 1 (glmnet's alpha)
# gives a pure lasso penalty.
lr_mod <-
  logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

# Pre-processing recipe: churn is the outcome, every other column is a
# predictor. The *_predictors() selectors guarantee the outcome is never
# filtered, normalized, or dummy-encoded.
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  # drop numeric predictors that are highly correlated (|r| > 0.9)
  step_corr(all_numeric_predictors(), threshold = 0.9) %>%
  # center and scale the remaining numeric predictors
  step_normalize(all_numeric_predictors()) %>%
  # expand factor predictors (e.g. state) into indicator columns
  step_dummy(all_nominal_predictors())

# model + recipe = workflow
churn_wkfl <- workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(churn_recipe)

# Cross-validation: 10 folds of the training data, stratified on churn
set.seed(1)
churn_folds <- vfold_cv(churn_train, v = 10, strata = churn)

# Tune the only tune() parameter (penalty). An integer grid asks tune to
# generate 25 candidate values automatically; each candidate is scored by
# ROC AUC across the folds.
set.seed(1)
glmnet_tuning <- tune_grid(
  churn_wkfl,
  resamples = churn_folds,
  grid = 25,
  metrics = metric_set(roc_auc)
)

# Keep the candidate with the best ROC AUC
best_glmnet_model <- select_best(glmnet_tuning, metric = "roc_auc")

# finalize the workflow and try to get adjusted predictions
# This does not work. Two likely reasons:
#  - tidy() is applied to the fitted workflow before predictions(), so
#    predictions() receives the tidied result rather than a fitted model
#    (NOTE(review): presumably a coefficient tibble; confirm);
#  - the underlying engine is glmnet, which marginaleffects does not
#    support (per the answer on this page), so even passing the fitted
#    engine model would fail.
final_churn_wkfl <- churn_wkfl %>%
  finalize_workflow(best_glmnet_model) %>%
  fit(churn_train) %>%
  tidy() %>%
  predictions(variables = c("state"))

Answer (from a CodePudding user):

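Unfortunately, glmnet is not one of the models supported by marginaleffects.

You can switch to one of the supported models (such as a regular glm(), which logistic_reg() uses by default). Then extract the underlying fitted model with extract_fit_engine() and pass it to predictions(). Note that predictions() needs a fitted model object; the coefficient table returned by tidy() will not work.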

library(tidymodels)
library(marginaleffects)
# Load the mlc_churn data and split it 75/25, stratified on churn;
# the seed makes the random split reproducible.
data("mlc_churn")

set.seed(123)
churn_split <- mlc_churn %>%
  initial_split(prop = 0.75, strata = churn)
churn_train <- churn_split %>% training()
churn_test <- churn_split %>% testing()

# Pre-processing recipe: churn is the outcome, every other column is a
# predictor. The *_predictors() selectors keep the outcome untouched.
# Because step_dummy() replaces factor predictors such as state with
# indicator columns (state_AL, state_AR, ...), the fitted engine model has
# no column named "state".
churn_recipe <- recipe(churn ~ ., data = churn_train) %>%
  step_corr(all_numeric_predictors(), threshold = 0.9) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_dummy(all_nominal_predictors())

# model + recipe = workflow; logistic_reg() with no engine set uses glm,
# a model type marginaleffects supports
churn_wkfl <- workflow(churn_recipe, logistic_reg())

# Fit the workflow (recipe + glm) on the training data, pull out the
# underlying glm object, and get predictions for a range of
# total_intl_calls values with the other predictors held fixed.
# total_intl_calls is on the normalized scale produced by the recipe.
churn_fit <- churn_wkfl %>% fit(data = churn_train)
churn_glm <- extract_fit_engine(churn_fit)
predictions(churn_glm, variables = c("total_intl_calls")) %>%
  as_tibble()
#> # A tibble: 5 × 71
#>   rowid type     predicted std.error conf.low conf.high account_length
#>   <int> <chr>        <dbl>     <dbl>    <dbl>     <dbl>          <dbl>
#> 1     1 response     0.895   0.0119     0.870     0.916       1.76e-17
#> 2     2 response     0.917   0.00620    0.904     0.928       1.76e-17
#> 3     3 response     0.923   0.00543    0.912     0.933       1.76e-17
#> 4     4 response     0.934   0.00549    0.923     0.944       1.76e-17
#> 5     5 response     0.977   0.00840    0.953     0.989       1.76e-17
#> # … with 64 more variables: number_vmail_messages <dbl>,
#> #   total_day_minutes <dbl>, total_day_calls <dbl>, total_eve_minutes <dbl>,
#> #   total_eve_calls <dbl>, total_night_calls <dbl>, total_night_charge <dbl>,
#> #   total_intl_minutes <dbl>, number_customer_service_calls <dbl>,
#> #   state_AL <dbl>, state_AR <dbl>, state_AZ <dbl>, state_CA <dbl>,
#> #   state_CO <dbl>, state_CT <dbl>, state_DC <dbl>, state_DE <dbl>,
#> #   state_FL <dbl>, state_GA <dbl>, state_HI <dbl>, state_IA <dbl>, …

Created on 2022-03-25 by the reprex package (v2.0.1)

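Notice that I did not use variables = c("state") and instead substituted one of the continuous, numeric predictors. The recipe's step_dummy() replaces state with indicator columns (state_AL, state_AR, …, as shown in the output above), so the extracted engine model no longer has a variable named state. Also, because step_normalize() runs in the recipe, the total_intl_calls values in the output are on the normalized (centered and scaled) scale, not the original one.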

  • Related