How do you get the test error metrics from a `tune_grid` object?

Time:04-05

I am confused by the output of tune::tune_grid(). Essentially, I would like to get the root mean squared errors (RMSEs) for any given set of hyperparameters in the grid.

For example, the following code uses 10-fold cross-validation to try 50 different penalty values in a ridge regression.

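# Packages: tidymodels (rsample, recipes, parsnip, workflows, dials, tune); ISLR supplies the data; glmnet must be installed
library(tidymodels)

# Silly data
df <- ISLR::College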

# 10 folds
set.seed(42)
cv <- vfold_cv(data = df, v = 10)

# Normalize predictors in a pipeline
recipe <- 
  recipe(formula = Apps ~ ., data = df) %>% 
  step_novel(all_nominal_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

# Ridge regression instance with tuneable `penalty`
ridge_spec <- 
  linear_reg(penalty = tune(), mixture = 0) %>% 
  set_mode("regression") %>% 
  set_engine("glmnet")

# Last two steps in a workflow
ridge_workflow <- workflow() %>% 
  add_recipe(recipe) %>%    # Normalize
  add_model(ridge_spec)     # Fit

# Grid of penalty hyperparameters
penalty_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)

# Fit a model per penalty value on 10 folds
ridge_grid <- tune_grid(
  object = ridge_workflow,
  resamples = cv, 
  grid = penalty_grid, 
  control = control_grid(verbose = FALSE)
)

I would like to get the 10 RMSEs (one per fold) of the best model.

I thought ridge_grid$.metrics would have this information, but it holds 10 tibbles, one per fold, each with 100 rows. What do these mean?

How can I get the 10 RMSEs of the best model?

Answer:

ridge_grid$.metrics holds all the holdout (assessment-set) performance estimates: one tibble per fold, with one row per candidate penalty and metric. To get the average metric value across the 10 folds for each parameter combination, you can use collect_metrics():

estimates <- collect_metrics(ridge_grid)
estimates

# A tibble: 100 × 7
     penalty .metric .estimator     mean     n   std_err .config              
       <dbl> <chr>   <chr>         <dbl> <int>     <dbl> <chr>                
 1 0.00001   rmse    standard   1183.       10 162.      Preprocessor1_Model01
 2 0.00001   rsq     standard      0.913    10   0.00823 Preprocessor1_Model01
 3 0.0000160 rmse    standard   1183.       10 162.      Preprocessor1_Model02
 4 0.0000160 rsq     standard      0.913    10   0.00823 Preprocessor1_Model02
 5 0.0000256 rmse    standard   1183.       10 162.      Preprocessor1_Model03
 6 0.0000256 rsq     standard      0.913    10   0.00823 Preprocessor1_Model03
 7 0.0000409 rmse    standard   1183.       10 162.      Preprocessor1_Model04
 8 0.0000409 rsq     standard      0.913    10   0.00823 Preprocessor1_Model04
 9 0.0000655 rmse    standard   1183.       10 162.      Preprocessor1_Model05
10 0.0000655 rsq     standard      0.913    10   0.00823 Preprocessor1_Model05
# … with 90 more rows
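To see how `.metrics` relates to this summary, here is a minimal self-contained sketch. It assumes tidymodels and glmnet are installed and swaps in the built-in `mtcars` data with only 5 penalties so it runs quickly without ISLR; the `toy_*`, `per_fold` and `manual_means` names are hypothetical. `collect_metrics(summarize = FALSE)` stacks the per-fold tibbles with a fold `id` column, and averaging those per-fold RMSEs should reproduce the `mean` column above.

```r
library(tidymodels)  # assumes tidymodels and glmnet are installed

set.seed(42)
toy_folds <- vfold_cv(mtcars, v = 10)  # stand-in data, not the College set

toy_wf <- workflow() %>%
  add_recipe(recipe(mpg ~ ., data = mtcars) %>% step_normalize(all_predictors())) %>%
  add_model(linear_reg(penalty = tune(), mixture = 0) %>% set_engine("glmnet"))

# 5 candidate penalties instead of 50, to keep the example fast
toy_grid <- tune_grid(
  toy_wf,
  resamples = toy_folds,
  grid = grid_regular(penalty(range = c(-5, 5)), levels = 5)
)

# One tibble per fold; each has 5 penalties x 2 default metrics (rmse, rsq) = 10 rows
nrow(toy_grid$.metrics[[1]])

# Same information stacked into one tibble, with the fold id attached
per_fold <- collect_metrics(toy_grid, summarize = FALSE)

# Averaging the per-fold RMSEs by candidate recovers the summarized `mean`
manual_means <- per_fold %>%
  filter(.metric == "rmse") %>%
  group_by(.config) %>%
  summarise(mean = mean(.estimate), n = n())
manual_means
```

With the question's objects, the same calls work on `ridge_grid`; there each `.metrics` tibble has 50 x 2 = 100 rows.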

To keep only the RMSE rows (each `mean` is the average over the 10 resamples) and sort them, you can use the following code:

rmse_vals <- 
  estimates %>% 
  dplyr::filter(.metric == "rmse") %>% 
  arrange(desc(mean))
rmse_vals

# A tibble: 50 × 7
   penalty .metric .estimator  mean     n std_err .config              
     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1 100000  rmse    standard   3332.    10    309. Preprocessor1_Model50
 2  62506. rmse    standard   3136.    10    306. Preprocessor1_Model49
 3  39069. rmse    standard   2884.    10    302. Preprocessor1_Model48
 4  24421. rmse    standard   2589.    10    295. Preprocessor1_Model47
 5  15264. rmse    standard   2281.    10    285. Preprocessor1_Model46
 6   9541. rmse    standard   1993.    10    272. Preprocessor1_Model45
 7   5964. rmse    standard   1753.    10    258. Preprocessor1_Model44
 8   3728. rmse    standard   1568.    10    243. Preprocessor1_Model43
 9   2330. rmse    standard   1435.    10    227. Preprocessor1_Model42
10   1456. rmse    standard   1342.    10    211. Preprocessor1_Model41
# … with 40 more rows

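Note that arrange(desc(mean)) sorts from the largest (worst) RMSE down, so the best penalty values end up at the bottom of this table; arrange(mean) would put the lowest RMSE first. Plotting RMSE against the penalty is a quick way to sanity-check these values: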

autoplot(ridge_grid, metric = "rmse")

[Figure: output of autoplot(ridge_grid, metric = "rmse"), plotting RMSE against penalty]
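The question also asks for the 10 per-fold RMSEs of the best model. One way, sketched below under the same assumptions as a self-contained toy (tidymodels and glmnet installed, `mtcars` standing in for the College data; `toy_*` and `best_fold_rmse` are hypothetical names), is to pick the winning candidate with `select_best()` and then filter the unsummarized metrics on its `.config`. With the question's objects, substitute `ridge_grid` for `toy_grid`.

```r
library(tidymodels)  # assumes tidymodels and glmnet are installed

# Toy tuning result (stand-in for ridge_grid from the question)
set.seed(42)
toy_folds <- vfold_cv(mtcars, v = 10)
toy_wf <- workflow() %>%
  add_recipe(recipe(mpg ~ ., data = mtcars) %>% step_normalize(all_predictors())) %>%
  add_model(linear_reg(penalty = tune(), mixture = 0) %>% set_engine("glmnet"))
toy_grid <- tune_grid(
  toy_wf,
  resamples = toy_folds,
  grid = grid_regular(penalty(range = c(-5, 5)), levels = 5)
)

# Best candidate = lowest mean RMSE across the folds
best <- select_best(toy_grid, metric = "rmse")

# The 10 per-fold (assessment-set) RMSEs for that candidate
best_fold_rmse <- collect_metrics(toy_grid, summarize = FALSE) %>%
  filter(.metric == "rmse", .config == best$.config) %>%
  select(id, penalty, .estimate)
best_fold_rmse

# show_best() ranks candidates by ascending mean RMSE (best first)
show_best(toy_grid, metric = "rmse", n = 3)
```

These are cross-validation holdout estimates, not errors on a separate test set; the mean of `best_fold_rmse$.estimate` is the `mean` that `show_best()` reports for the top candidate.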
