I'm currently using the tidymodels framework and struggling to understand some differences in the model predictions and performance results I get, specifically when I use both fit() and predict() on the exact same dataset (i.e. the dataset the model was trained on).
Below is a reproducible example. I'm using the cells dataset and training a random forest on the data (rf_fit). The object rf_fit$fit$predictions is the first set of predictions whose accuracy I assess. I then use rf_fit to make predictions on the same data via the predict() function, yielding rf_training_pred, the second set of predictions whose accuracy I assess.
My question is: why are these two sets of predictions different from each other, and why are they so different?
I presume something is going on under the hood that I'm not aware of, but I expected these to be identical. I assumed that fit() trains a model (with some predictions associated with that trained model) and that predict() takes that exact model and simply re-applies it to, in this case, the same data, so the predictions from both should be identical.
What am I missing? Any suggestions or help in understanding would be hugely appreciated - thanks in advance!
# Load required libraries
library(tidymodels); library(modeldata)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
# Set seed
set.seed(123)
# Load the data
data(cells, package = "modeldata")
# Define Model
rf_mod <- rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("classification")
# Fit the model to training data and then predict on same training data
rf_fit <- rf_mod %>%
  fit(class ~ ., data = cells)
rf_training_pred <- rf_fit %>%
  predict(cells, type = "prob")
# Evaluate accuracy
data.frame(rf_fit$fit$predictions) %>%
  bind_cols(cells %>% select(class)) %>%
  roc_auc(truth = class, PS)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.903
rf_training_pred %>%
  bind_cols(cells %>% select(class)) %>%
  roc_auc(truth = class, .pred_PS)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary          1.00
Created on 2021-09-25 by the reprex package (v2.0.1)
CodePudding user response:
When you use fit, you are fitting a model to your data (cells). The model fit will never be perfect, hence your accuracy < 1.
When you use predict, you are taking this model and trying to predict the values of your cells data with it. Notice the circular procedure here? Since the model you are using to predict the cells data comes from the cells data itself, you get a perfect accuracy of 1.
CodePudding user response:
First off, look at the documentation for what ranger::ranger() returns, especially what predictions is:
Predicted classes/values, based on out of bag samples (classification and regression only).
This isn't the same as what you get when you predict with the final fitted model on the whole dataset.
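In other words, each row of rf_fit$fit$predictions is averaged over only the trees that did not see that row during fitting, which is why the resulting ROC AUC (0.903) is lower, and more honest, than re-predicting the training data. As a quick check, ranger also stores an aggregate out-of-bag error estimate (a sketch, assuming the rf_fit object from the question):
# ranger's overall OOB prediction error; for probability forests
# this is the Brier score computed from the OOB predictions
rf_fit$fit$prediction.error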
Second, when you predict with the final model, you get the same results whether you predict via the tidymodels object or the underlying ranger object.
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(modeldata)
data(cells, package = "modeldata")
cells <- cells %>% select(-case)
# Define Model
rf_mod <- rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("classification")
# Fit the model to training data and then predict on same training data
rf_fit <- rf_mod %>%
  fit(class ~ ., data = cells)
tidymodels_results <- predict(rf_fit, cells, type = "prob")
tidymodels_results
#> # A tibble: 2,019 × 2
#>    .pred_PS .pred_WS
#>       <dbl>    <dbl>
#>  1   0.929    0.0706
#>  2   0.764    0.236
#>  3   0.222    0.778
#>  4   0.920    0.0796
#>  5   0.961    0.0386
#>  6   0.0486   0.951
#>  7   0.101    0.899
#>  8   0.954    0.0462
#>  9   0.293    0.707
#> 10   0.405    0.595
#> # … with 2,009 more rows
ranger_results <- predict(rf_fit$fit, cells, type = "response")
as_tibble(ranger_results$predictions)
#> # A tibble: 2,019 × 2
#>        PS     WS
#>     <dbl>  <dbl>
#>  1 0.929  0.0706
#>  2 0.764  0.236
#>  3 0.222  0.778
#>  4 0.920  0.0796
#>  5 0.961  0.0386
#>  6 0.0486 0.951
#>  7 0.101  0.899
#>  8 0.954  0.0462
#>  9 0.293  0.707
#> 10 0.405  0.595
#> # … with 2,009 more rows
Created on 2021-09-25 by the reprex package (v2.0.1)
NOTE: this only works because we have used very simple preprocessing. As we note here, you generally should not predict on the underlying $fit object.
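To illustrate why, here is a hypothetical sketch: once real preprocessing is involved (say, a recipe inside a workflow; the normalization step here is purely illustrative), predict() on the tidymodels object applies the preprocessing before the data reach the model, while the raw ranger object extracted from it knows nothing about the recipe:
# A workflow pairing a (hypothetical) normalization recipe with rf_mod
rf_wflow <- workflow() %>%
  add_recipe(
    recipe(class ~ ., data = cells) %>%
      step_normalize(all_numeric_predictors())
  ) %>%
  add_model(rf_mod)
rf_wflow_fit <- fit(rf_wflow, data = cells)
# Safe: the recipe is applied to new data automatically
predict(rf_wflow_fit, cells, type = "prob")
# Risky: the raw engine fit expects already-normalized data, so
# feeding it the raw cells data would silently skip that step
ranger_obj <- extract_fit_engine(rf_wflow_fit)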