I get an error when running predict in the tidymodels framework. The error appears to be related to selecting variables in the recipe (see code below).

What I've tried

There are some related SO posts, such as this one, this one, or this one, but they seem to deal with different issues (such as manipulating the outcome in the recipe).

However, I would like to understand why my code throws the error in the first place.

Code that throws error


library(tidymodels)

d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")
rec1 <- 
  recipe( ~ ., data = d_train) %>% 
  step_select(all_of(preds_chosen), mpg) %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")

model_lm <- linear_reg()
wf1 <-
  workflow() %>% 
  add_model(model_lm) %>% 
  add_recipe(rec1)
lm_fit1 <-
  wf1 %>% 
  fit(d_train)
preds <-
  lm_fit1 %>% 
  predict(new_data = d_test)
#> Error in `dplyr::select()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `mpg` doesn't exist.

Possible solution

If I change the recipe in the following ways, the whole code runs without an error:

rec2 <- recipe(mpg ~ hp + disp + am, data = d_train)
rec3 <- 
  recipe(mpg ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")


#> R version 4.1.3 (2022-03-10)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Big Sur/Monterey 10.16
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> loaded via a namespace (and not attached):
#>  [1] rstudioapi_0.13 knitr_1.39      magrittr_2.0.3  rlang_1.0.2    
#>  [5] fastmap_1.1.0   fansi_1.0.3     stringr_1.4.0   styler_1.5.1   
#>  [9] highr_0.9       tools_4.1.3     xfun_0.30       utf8_1.2.2     
#> [13] cli_3.3.0       withr_2.5.0     htmltools_0.5.2 ellipsis_0.3.2 
#> [17] yaml_2.3.5      digest_0.6.29   tibble_3.1.7    lifecycle_1.0.1
#> [21] crayon_1.5.1    purrr_0.3.4     vctrs_0.4.1     fs_1.5.2       
#> [25] glue_1.6.2      evaluate_0.15   rmarkdown_2.14  reprex_2.0.1   
#> [29] stringi_1.7.6   compiler_4.1.3  pillar_1.7.0    backports_1.4.1
#> [33] pkgconfig_2.0.3
Created on 2022-05-21 by the reprex package (v2.0.1)
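To check the possible solution above end to end, here is a minimal sketch, assuming the tidymodels meta-package is installed. It fits rec2 and rec3 in otherwise identical workflows and compares the predictions. Because rec3 only demotes the unused columns to the "id" role, the model should see the same three predictors in both cases, so the predictions are expected to agree up to floating-point error. The helper fit_predict() is hypothetical and exists only for this illustration.

```r
library(tidymodels)

d_train <- mtcars %>% slice(1:20)
d_test  <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

rec2 <- recipe(mpg ~ hp + disp + am, data = d_train)
rec3 <-
  recipe(mpg ~ ., data = d_train) %>%
  update_role(all_predictors(), new_role = "id") %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")

# Hypothetical helper: bundle a recipe with a plain linear model,
# fit on the training rows, and predict the held-out rows
fit_predict <- function(rec) {
  workflow() %>%
    add_model(linear_reg()) %>%
    add_recipe(rec) %>%
    fit(d_train) %>%
    predict(new_data = d_test)
}

preds2 <- fit_predict(rec2)
preds3 <- fit_predict(rec3)

# Both recipes should hand the model the same three predictors
all.equal(preds2$.pred, preds3$.pred)
```

Note that d_test still contains every column here, so the "id" columns required by rec3 are available at prediction time.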

Answer

The answer about using skip = TRUE may work for your situation, but if your preprocessing is very simple and you only need the recipe to specify roles, you may want to use add_variables() instead:


library(tidymodels)

d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

wf1 <-
  workflow() %>% 
  add_model(linear_reg()) %>% 
  add_variables(outcomes = mpg, predictors = !! preds_chosen)

lm_fit1 <- fit(wf1, d_train)
predict(lm_fit1, d_test)
#> # A tibble: 12 × 1
#>    .pred
#>    <dbl>
#>  1  22.6
#>  2  17.2
#>  3  17.4
#>  4  12.1
#>  5  14.9
#>  6  28.2
#>  7  26.3
#>  8  25.6
#>  9  14.6
#> 10  21.8
#> 11  11.7
#> 12  25.4

Created on 2022-05-22 by the reprex package (v2.0.1)
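To confirm that the add_variables() workflow is just an ordinary least-squares fit on the three chosen columns, here is a minimal sketch, assuming tidymodels is installed and that linear_reg() uses its default "lm" engine. It compares the workflow predictions with a base R lm() fit of the same formula; they are expected to match up to floating-point error. The object names wf_vars, preds_wf, fit_base, and preds_base are only illustrative.

```r
library(tidymodels)

d_train <- mtcars %>% slice(1:20)
d_test  <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

# Roles assigned directly in the workflow, no recipe involved
wf_vars <-
  workflow() %>%
  add_model(linear_reg()) %>%
  add_variables(outcomes = mpg, predictors = !! preds_chosen)

preds_wf <- predict(fit(wf_vars, d_train), d_test)

# The same model fitted with base R; unname() drops the car-name labels
fit_base   <- lm(mpg ~ hp + disp + am, data = d_train)
preds_base <- unname(predict(fit_base, newdata = d_test))

all.equal(preds_wf$.pred, preds_base)
```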

Answer

You can set skip = TRUE in step_select(). According to ?step_select:

skip - A logical. Should the step be skipped when the recipe is baked by bake()? While all operations are baked when prep() is run, some operations may not be able to be conducted on new data (e.g. processing the outcome variable(s)). Care should be taken when using skip = TRUE as it may affect the computations for subsequent operations.

rec1 <- 
  recipe( ~ ., data = d_train) %>% 
  step_select(all_of(preds_chosen), mpg, skip = TRUE) %>% 
  update_role(all_of(preds_chosen), new_role = "predictor") %>% 
  update_role(mpg, new_role = "outcome")

and then rerun the OP's workflow code:

> preds <-
    lm_fit1 %>% 
    predict(new_data = d_test)
> preds
# A tibble: 12 × 1
   .pred
   <dbl>
 1  22.6
 2  17.2
 3  17.4
 4  12.1
 5  14.9
 6  28.2
 7  26.3
 8  25.6
 9  14.6
10  21.8
11  11.7
12  25.4
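To see why skipping the step avoids the error, here is a minimal sketch, assuming (as the error message suggests) that at prediction time the workflow passes the new data to bake() without the outcome column mpg. It preps both versions of the OP's recipe on the training rows and bakes them on d_test with mpg removed. Without skip = TRUE, step_select() still tries to select mpg and fails. With skip = TRUE, the step runs only during prep() and is left out when new data are baked. The helper make_rec() and the object new_data_no_outcome are hypothetical.

```r
library(tidymodels)

d_train <- mtcars %>% slice(1:20)
d_test  <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

# Hypothetical helper: the OP's recipe, with or without skip = TRUE
make_rec <- function(skip) {
  recipe(~ ., data = d_train) %>%
    step_select(all_of(preds_chosen), mpg, skip = skip) %>%
    update_role(all_of(preds_chosen), new_role = "predictor") %>%
    update_role(mpg, new_role = "outcome")
}

# Assumption: new data at predict time arrive without the outcome column
new_data_no_outcome <- d_test %>% select(-mpg)

# Without skip: step_select() runs on the new data and cannot find `mpg`
baked_noskip <- tryCatch(
  bake(prep(make_rec(skip = FALSE)), new_data = new_data_no_outcome),
  error = function(e) e
)
inherits(baked_noskip, "error")

# With skip: the step was applied to the training data during prep() ...
prepped_skip <- prep(make_rec(skip = TRUE))
names(bake(prepped_skip, new_data = NULL))

# ... but is skipped when new data are baked, so no error is raised
baked_skip <- bake(prepped_skip, new_data = new_data_no_outcome)
nrow(baked_skip)
```

The skipped version returns all remaining columns of the new data; the workflow then keeps only the predictors the model was trained on.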
