Problem
I get an error when running predict() in the tidymodels framework.
The error appears to be related to selecting variables in the recipe (see code below).
What I've tried
There are some related SO posts, but they seem to deal with different issues (such as manipulating the outcome in the recipe).
However, I would like to understand why my code throws the error in the first place.
Code that throws error
library(tidyverse)
library(tidymodels)
data("mtcars")
d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")
rec1 <-
  recipe( ~ ., data = d_train) %>%
  step_select(all_of(preds_chosen), mpg) %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")
model_lm <- linear_reg()
wf1 <-
  workflow() %>%
  add_model(model_lm) %>%
  add_recipe(rec1)
lm_fit1 <-
  wf1 %>%
  fit(d_train)
preds <-
  lm_fit1 %>%
  predict(d_test)
#> Error in `dplyr::select()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `mpg` doesn't exist.
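A possible way to see where the error comes from, using recipes alone. The sketch below rests on an assumption that is not confirmed in this thread: that predict() on a workflow does not pass outcome columns to the recipe. Dropping mpg from the test data by hand is only a stand-in for that behaviour. The names rec1_prepped, baked_ok and bake_result are made up for illustration.

```r
library(recipes)
library(dplyr)

d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

# Same recipe as in the question
rec1 <-
  recipe( ~ ., data = d_train) %>%
  step_select(all_of(preds_chosen), mpg) %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")

rec1_prepped <- prep(rec1, training = d_train)

# Baking data that still contains mpg works: the select step finds all four columns
baked_ok <- bake(rec1_prepped, new_data = d_test)
names(baked_ok)

# Assumption: mimic prediction time by removing the outcome before baking.
# The select step was trained to keep mpg, so it has nothing to select here.
bake_result <- tryCatch(
  bake(rec1_prepped, new_data = select(d_test, -mpg)),
  error = function(e) e
)

# TRUE if baking failed; the exact message depends on the recipes version
inherits(bake_result, "error")
conditionMessage(bake_result)
```

If that assumption holds, any step that names the outcome explicitly will hit the same problem at prediction time, which would explain why the recipes that set roles through the formula do not fail.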
Possible solution
If I change the recipe in the following ways, the whole code runs without an error:
rec2 <- recipe(mpg ~ hp + disp + am, data = d_train)
rec3 <-
  recipe(mpg ~ ., data = d_train) %>%
  update_role(all_predictors(), new_role = "id") %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")
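To check what these two recipes actually declare, the roles can be inspected with summary() before any model is fitted. This is only a quick sanity check, not part of the original code; roles_rec2 and roles_rec3 are illustrative names.

```r
library(recipes)
library(dplyr)

d_train <- mtcars %>% slice(1:20)
preds_chosen <- c("hp", "disp", "am")

# Roles come straight from the formula
rec2 <- recipe(mpg ~ hp + disp + am, data = d_train)

# Roles are reassigned after the fact; no step touches the outcome
rec3 <-
  recipe(mpg ~ ., data = d_train) %>%
  update_role(all_predictors(), new_role = "id") %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")

# summary() on a recipe returns one row per variable with its role
roles_rec2 <- summary(rec2)
roles_rec3 <- summary(rec3)

roles_rec2
count(roles_rec3, role)
```

Neither recipe contains a step that refers to mpg, so nothing in them needs the outcome column when new data is processed.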
SessionInfo
sessionInfo()
#> R version 4.1.3 (2022-03-10)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Big Sur/Monterey 10.16
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> loaded via a namespace (and not attached):
#> [1] rstudioapi_0.13 knitr_1.39 magrittr_2.0.3 rlang_1.0.2
#> [5] fastmap_1.1.0 fansi_1.0.3 stringr_1.4.0 styler_1.5.1
#> [9] highr_0.9 tools_4.1.3 xfun_0.30 utf8_1.2.2
#> [13] cli_3.3.0 withr_2.5.0 htmltools_0.5.2 ellipsis_0.3.2
#> [17] yaml_2.3.5 digest_0.6.29 tibble_3.1.7 lifecycle_1.0.1
#> [21] crayon_1.5.1 purrr_0.3.4 vctrs_0.4.1 fs_1.5.2
#> [25] glue_1.6.2 evaluate_0.15 rmarkdown_2.14 reprex_2.0.1
#> [29] stringi_1.7.6 compiler_4.1.3 pillar_1.7.0 backports_1.4.1
#> [33] pkgconfig_2.0.3
Created on 2022-05-21 by the reprex package (v2.0.1)
CodePudding user response:
The answer about how to use skip = TRUE may work for your situation, but if you are using very simple preprocessing and don't really need the recipe except for specifying roles, you may want to look into using add_variables():
library(tidyverse)
library(tidymodels)
data("mtcars")
d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")
wf1 <-
  workflow() %>%
  add_model(linear_reg()) %>%
  add_variables(outcomes = mpg, predictors = !! preds_chosen)
lm_fit1 <- fit(wf1, d_train)
predict(lm_fit1, d_test)
#> # A tibble: 12 × 1
#> .pred
#> <dbl>
#> 1 22.6
#> 2 17.2
#> 3 17.4
#> 4 12.1
#> 5 14.9
#> 6 28.2
#> 7 26.3
#> 8 25.6
#> 9 14.6
#> 10 21.8
#> 11 11.7
#> 12 25.4
Created on 2022-05-22 by the reprex package (v2.0.1)
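As a side note, the predictors argument of add_variables() takes a tidyselect expression, so all_of(preds_chosen) should be usable in place of !! preds_chosen. The sketch below compares the two spellings; wf_bang, wf_all_of, pred_bang and pred_all_of are hypothetical names, and I have not checked this against every workflows version.

```r
library(tidymodels)

d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

# Variant from the answer above: unquote the character vector
wf_bang <-
  workflow() %>%
  add_model(linear_reg()) %>%
  add_variables(outcomes = mpg, predictors = !! preds_chosen)

# Variant with a tidyselect helper (assumed to be equivalent)
wf_all_of <-
  workflow() %>%
  add_model(linear_reg()) %>%
  add_variables(outcomes = mpg, predictors = all_of(preds_chosen))

pred_bang <- predict(fit(wf_bang, d_train), d_test)
pred_all_of <- predict(fit(wf_all_of, d_train), d_test)

# Both workflows use the same three predictors, so the predictions should match
all.equal(pred_bang, pred_all_of)
```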
CodePudding user response:
We may use skip = TRUE in step_select(). According to ?step_select:
skip - A logical. Should the step be skipped when the recipe is baked by bake()? While all operations are baked when prep() is run, some operations may not be able to be conducted on new data (e.g. processing the outcome variable(s)). Care should be taken when using skip = TRUE as it may affect the computations for subsequent operations.
rec1 <-
  recipe( ~ ., data = d_train) %>%
  step_select(all_of(preds_chosen), mpg, skip = TRUE) %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")
After rebuilding and refitting the workflow with this recipe, the OP's prediction code runs without the error:
> preds <-
lm_fit1 %>%
predict(d_test)
> preds
# A tibble: 12 × 1
.pred
<dbl>
1 22.6
2 17.2
3 17.4
4 12.1
5 14.9
6 28.2
7 26.3
8 25.6
9 14.6
10 21.8
11 11.7
12 25.4
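The effect of skip = TRUE can also be checked on the recipe itself, without a workflow. This is a minimal sketch: dropping mpg by hand is an assumption that stands in for whatever predict() passes to the recipe, and rec_skip, d_no_outcome and baked_skip are illustrative names.

```r
library(recipes)
library(dplyr)

d_train <- mtcars %>% slice(1:20)
d_test <- mtcars %>% slice(21:nrow(mtcars))
preds_chosen <- c("hp", "disp", "am")

rec_skip <-
  recipe( ~ ., data = d_train) %>%
  step_select(all_of(preds_chosen), mpg, skip = TRUE) %>%
  update_role(all_of(preds_chosen), new_role = "predictor") %>%
  update_role(mpg, new_role = "outcome")

# The select step still runs on the training data during prep()
rec_skip_prepped <- prep(rec_skip, training = d_train)

# Assumption: new data without the outcome, as at prediction time
d_no_outcome <- select(d_test, -mpg)

# With skip = TRUE the select step is not applied in bake(),
# so the missing mpg column is never requested
baked_skip <- tryCatch(
  bake(rec_skip_prepped, new_data = d_no_outcome, all_predictors()),
  error = function(e) e
)

inherits(baked_skip, "error")
names(baked_skip)
```

As the quoted documentation warns, a skipped step is not applied to new data at all, so this is only safe when later steps do not depend on what the skipped step would have done.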