Using the caret package in R, I am performing random forest (RF) regression, and part of the analysis is fine-tuning the model. Because I have only one independent variable, the RF parameters I can tune are maxnodes and ntree. I have written code that iteratively fits several models and exports the results in order:
- to find the 'best' maxnodes model,
- to find the 'best' ntree model.
By 'best' I mean the model with the lowest RMSE (red circle in the second picture below). The problem is that the output is a list of every model's RMSE, and I have to search it manually to see which model yielded the lowest value. An example of the list I have to search for the 'best' model:
I was wondering if there is a way to store the 'best' model for both maxnodes and ntree (the two loops I'm using) in a variable, without having to search for it manually. For example, in the image below, the number 9 is the value I want to store in a variable:
After I store the value from the first loop, I use it in the second loop as the value of maxnodes.
library(caret)
block.data = read.csv("path/block.data.csv")
set.seed(123)
samp <- sample(nrow(block.data), 0.75 * nrow(block.data))
train <- block.data[samp, ]
test <- block.data[-samp, ]
# define the control
trControl = trainControl(method = "repeatedcv",
                         number = 10,
                         search = "grid")
rf_default = train(ntl ~ .,
                   data = train,
                   method = "rf",
                   metric = "RMSE",
                   trControl = trControl)
print(rf_default)
# search best maxnodes
store_maxnode = list()
tuneGrid = expand.grid(.mtry = 1)
for (maxnodes in c(5:15)) {
  set.seed(456)
  rf_maxnode = train(ntl ~ .,
                     data = train,
                     method = "rf",
                     metric = "RMSE",
                     tuneGrid = tuneGrid,
                     trControl = trControl,
                     importance = TRUE,
                     nodesize = 14,
                     maxnodes = maxnodes,
                     ntree = 500)
  current_iteration = toString(maxnodes)
  store_maxnode[[current_iteration]] = rf_maxnode
}
results_mtry = resamples(store_maxnode) # from here I want to extract the 'best' model
summary(results_mtry)
# search best ntree
store_maxtrees = list()
for (ntree in seq(from = 500, to = 2000, by = 100)) {
  set.seed(789)
  rf_maxtrees = train(ntl ~ .,
                      data = train,
                      method = "rf",
                      metric = "RMSE",
                      tuneGrid = tuneGrid,
                      trControl = trControl,
                      importance = TRUE,
                      nodesize = 14,
                      maxnodes = 10, # this value is from my tests
                      ntree = ntree)
  key = toString(ntree)
  store_maxtrees[[key]] = rf_maxtrees
}
results_trees = resamples(store_maxtrees) # from here I want to extract the 'best' model
summary(results_trees)
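One possible way to avoid the manual search, sketched under assumptions: each object returned by train() carries a results data frame whose RMSE column holds the cross-validated RMSE, so the lists built in the loops can be scanned directly. To stay runnable without caret, the sketch below replaces store_maxnode with a mock list whose RMSE numbers are made-up placeholders; cv_rmse and best_maxnodes are names introduced here for illustration.

```r
# Mock of store_maxnode: each element mimics only the `results` data frame
# of a caret train object. The RMSE numbers are placeholders, not real output.
store_maxnode <- list(
  "5" = list(results = data.frame(mtry = 1, RMSE = 10.9)),
  "6" = list(results = data.frame(mtry = 1, RMSE = 10.7)),
  "7" = list(results = data.frame(mtry = 1, RMSE = 10.8))
)

# One cross-validated RMSE per stored model, named by the list keys
cv_rmse <- sapply(store_maxnode, function(m) min(m$results$RMSE))

# The key of the smallest RMSE is the maxnodes value, stored as an integer
best_maxnodes <- as.integer(names(which.min(cv_rmse)))
print(best_maxnodes)
```

With the real list, the mock definition would be dropped, and best_maxnodes could be passed as maxnodes = best_maxnodes in the second loop. The same two lines applied to store_maxtrees would give the ntree value.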
The structure of results_mtry (see the first loop above):
str(results_mtry)
List of 6
$ call : language resamples.default(x = store_maxnode)
$ values :'data.frame': 10 obs. of 34 variables:
..$ Resample : chr [1:10] "Fold01.Rep1" "Fold02.Rep1" "Fold03.Rep1" "Fold04.Rep1" ...
..$ 5~MAE : num [1:10] 7.19 10.76 5.94 9.8 7.19 ...
..$ 5~RMSE : num [1:10] 9.65 14.01 7.83 12.07 9.3 ...
..$ 5~Rsquared : num [1:10] 0.424579 0.009119 0.764428 0.000589 0.47881 ...
..$ 6~MAE : num [1:10] 7.17 10.69 6.08 9.67 7.11 ...
..$ 6~RMSE : num [1:10] 9.62 13.98 7.93 11.95 9.06 ...
..$ 6~Rsquared : num [1:10] 0.42625 0.01011 0.76318 0.00568 0.50504 ...
..$ 7~MAE : num [1:10] 7.15 10.68 6.19 9.62 7.11 ...
..$ 7~RMSE : num [1:10] 9.57 13.99 7.96 11.99 9.02 ...
..$ 7~Rsquared : num [1:10] 0.43125 0.00794 0.76992 0.00555 0.51089 ...
..$ 8~MAE : num [1:10] 7.12 10.79 6.2 9.56 7.03 ...
..$ 8~RMSE : num [1:10] 9.54 14.07 7.85 11.94 8.84 ...
..$ 8~Rsquared : num [1:10] 0.4352 0.0075 0.7752 0.0118 0.5366 ...
..$ 9~MAE : num [1:10] 7.09 10.88 6.21 9.58 6.97 ...
..$ 9~RMSE : num [1:10] 9.46 14.24 7.83 11.92 8.79 ...
..$ 9~Rsquared : num [1:10] 0.4435 0.00364 0.77259 0.01597 0.54392 ...
..$ 10~MAE : num [1:10] 7.14 10.96 6.18 9.64 6.86 ...
..$ 10~RMSE : num [1:10] 9.49 14.28 7.82 11.94 8.68 ...
..$ 10~Rsquared: num [1:10] 0.44116 0.00333 0.77282 0.01974 0.55635 ...
..$ 11~MAE : num [1:10] 7.11 10.97 6.32 9.67 6.85 ...
..$ 11~RMSE : num [1:10] 9.47 14.33 7.91 12.03 8.67 ...
..$ 11~Rsquared: num [1:10] 0.44298 0.00205 0.77089 0.01725 0.56302 ...
..$ 12~MAE : num [1:10] 7.09 11.05 6.21 9.64 6.83 ...
..$ 12~RMSE : num [1:10] 9.4 14.41 7.82 11.99 8.67 ...
..$ 12~Rsquared: num [1:10] 0.4508 0.0016 0.7759 0.0206 0.5624 ...
..$ 13~MAE : num [1:10] 7.2 11.08 6.18 9.67 6.79 ...
..$ 13~RMSE : num [1:10] 9.48 14.48 7.72 12.07 8.57 ...
..$ 13~Rsquared: num [1:10] 0.441933 0.000875 0.779608 0.016986 0.576059 ...
..$ 14~MAE : num [1:10] 7.18 11.16 6.2 9.7 6.86 ...
..$ 14~RMSE : num [1:10] 9.42 14.65 7.74 12.11 8.61 ...
..$ 14~Rsquared: num [1:10] 4.49e-01 6.88e-05 7.73e-01 1.47e-02 5.71e-01 ...
..$ 15~MAE : num [1:10] 7.13 11.3 6.33 9.76 6.88 ...
..$ 15~RMSE : num [1:10] 9.42 14.7 7.78 12.16 8.62 ...
..$ 15~Rsquared: num [1:10] 4.48e-01 1.39e-08 7.69e-01 1.59e-02 5.71e-01 ...
$ models : chr [1:11] "5" "6" "7" "8" ...
$ metrics: chr [1:3] "MAE" "RMSE" "Rsquared"
$ timings:'data.frame': 11 obs. of 3 variables:
..$ Everything: num [1:11] 2.68 1.64 1.17 1.01 1.53 ...
..$ FinalModel: num [1:11] 0.11 0.06 0.06 0.07 0.09 ...
..$ Prediction: num [1:11] NA NA NA NA NA NA NA NA NA NA ...
$ methods: Named chr [1:11] "rf" "rf" "rf" "rf" ...
..- attr(*, "names")= chr [1:11] "5" "6" "7" "8" ...
- attr(*, "class")= chr "resamples"
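Given the structure above, the per-resample RMSE of each model sits in the "<model>~RMSE" columns of results_mtry$values, so one hedged option is to average those columns and keep the name of the smallest. The sketch below rebuilds a small stand-in for results_mtry$values from the first four resample values of models 5, 9 and 10 printed above, so its answer applies only to that subset; the names values, rmse_cols, mean_rmse and best_maxnodes are assumptions introduced here.

```r
# Stand-in for results_mtry$values: first four folds of three models,
# copied from the str() output. With the real object, use
# values <- results_mtry$values instead.
values <- data.frame(
  Resample  = c("Fold01.Rep1", "Fold02.Rep1", "Fold03.Rep1", "Fold04.Rep1"),
  `5~RMSE`  = c(9.65, 14.01, 7.83, 12.07),
  `9~RMSE`  = c(9.46, 14.24, 7.83, 11.92),
  `10~RMSE` = c(9.49, 14.28, 7.82, 11.94),
  check.names = FALSE  # keep the "model~metric" column names unchanged
)

# Select the RMSE columns and average each one over the resamples
rmse_cols <- grep("~RMSE$", names(values), value = TRUE)
mean_rmse <- colMeans(values[rmse_cols])

# Strip the "~RMSE" suffix from the winning column to recover maxnodes
best_maxnodes <- as.integer(sub("~RMSE$", "", names(which.min(mean_rmse))))
print(best_maxnodes)
```

Because the model names are the keys used in the loop, the same pattern applied to results_trees$values would return the ntree value instead.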
My dataset:
block.data = structure(list(ntl = c(59.2700004577637, 116.720001220703, 91.8600006103516,
50.3199996948242, 130.020004272461, 97.0800018310547, 69.3300018310547,
49.6699981689453, 46.75, 62.4099998474121, 62.7099990844727,
51.3499984741211, 44.6500015258789, 34.8499984741211, 39.2200012207031,
42.439998626709, 45.4700012207031, 38.7900009155273, 40.2000007629395,
38.8199996948242, 48.2400016784668, 45.7000007629395, 47.810001373291,
53.1300010681152, 49.560001373291, 43.5900001525879, 44.2000007629395,
38.9300003051758, 34.8499984741211, 34.4799995422363, 56.3600006103516,
55, 45.0499992370605, 43.439998626709, 40.9700012207031, 47.0499992370605,
48.0999984741211, 57.9199981689453, 51.310001373291, 38.5400009155273,
43.6800003051758, 50.25, 48.8800010681152, 41.2599983215332,
40.75, 42.9599990844727, 46.6599998474121, 48.3800010681152,
61.8199996948242, 61.0800018310547, 41.2700004577637, 46.1199989318848,
22.8799991607666, 47.3699989318848, 43.6300010681152, 38.4300003051758,
40.8699989318848, 48.7299995422363, 44.4700012207031, 45.4900016784668,
61.6199989318848, 54.9500007629395, 37.9599990844727, 43.439998626709,
41.4000015258789, 47.9900016784668, 41.9700012207031, 37.2299995422363,
42.1800003051758, 49.560001373291, 43.7099990844727, 46.0800018310547,
54.4000015258789, 52.4599990844727, 44.1699981689453, 48.8199996948242,
17.2999992370605, 35.7799987792969, 57.0699996948242, 52.4599990844727,
41.9000015258789, 36.2599983215332, 38.8400001525879, 42.1699981689453,
42.4599990844727, 39.0099983215332, 38.9799995422363, 38.7599983215332,
41.9300003051758, 30.3400001525879, 35.4599990844727, 48.6100006103516,
55.5400009155273, 57.689998626709, 45.6300010681152, 39.0999984741211,
34.5999984741211, 40.5099983215332, 34.0699996948242, 31.5900001525879,
28.6100006103516, 24.7299995422363, 29.7299995422363, 32.1300010681152,
39.5400009155273, 41.7900009155273, 63.9000015258789, 62.810001373291,
49.3499984741211, 49.9000015258789, 47.4700012207031, 43.689998626709,
43.1300010681152, 41.3499984741211, 33.560001373291, 26.7000007629395,
20.1499996185303, 13.960000038147, 18.7399997711182, 42.3899993896484,
41.6500015258789, 57.4500007629395, 60.3300018310547, 61.9599990844727,
61.9199981689453, 48.5800018310547, 41.4099998474121, 43.1100006103516,
45.4500007629395, 52.4500007629395, 58.8699989318848, 44.8400001525879,
28.1599998474121, 15.5299997329712, 11.5100002288818, 16.3199996948242,
34.2099990844727, 49.1100006103516, 54.6599998474121, 52.1500015258789,
52.9300003051758, 58.1500015258789, 52.1100006103516, 43.7299995422363,
46.8699989318848, 53.2900009155273, 57.310001373291, 64.5400009155273,
40.9000015258789, 22.5799999237061, 11.5600004196167, 9.77000045776367,
32.4099998474121, 47.4199981689453, 51.3300018310547, 56.1100006103516,
50.8800010681152, 44.560001373291, 42.4700012207031, 44.1599998474121,
48.75, 60.1699981689453, 49.9500007629395, 37.4000015258789,
36.1699981689453, 53.8800010681152, 58.6599998474121, 61.1500015258789,
50.3699989318848, 36.1800003051758, 42.8600006103516, 42.7999992370605,
46.8699989318848, 48.1800003051758, 44.3699989318848, 30.1000003814697,
52.3699989318848, 63.2599983215332, 62, 52.6699981689453, 41.5299987792969,
48.4900016784668, 46.9500007629395, 41.7000007629395, 44.2000007629395,
38.0999984741211, 27.0400009155273, 40.6300010681152, 49.7099990844727,
49.7099990844727, 39.8800010681152, 43.6599998474121, 46.7000007629395,
39.2200012207031, 35.0999984741211, 26.2900009155273, 37.9700012207031,
43.2299995422363, 40.689998626709, 42.5, 43.0299987792969, 38.4300003051758,
34.3899993896484, 30.6700000762939, 26, 20.8700008392334, 36.9900016784668,
38.25, 57.6800003051758, 52.8400001525879, 41.1100006103516,
39.9300003051758, 31.2600002288818, 27.9400005340576, 30.9599990844727,
32.9799995422363, 22.3099994659424, 35.5099983215332, 38.4700012207031,
53.5800018310547, 78.1100006103516, 62.4500007629395, 45.8400001525879,
42.75, 38.4700012207031, 26.7999992370605, 23.5, 23.5900001525879,
31.7299995422363, 39.8300018310547, 46.439998626709, 64.1500015258789,
63.1800003051758, 56.3300018310547, 50.4099998474121, 47.5, 45.939998626709,
30.1100006103516, 27.5200004577637, 33.8899993896484, 40.5999984741211,
54.2299995422363, 67.9499969482422, 62.1399993896484, 55.5499992370605,
52.8499984741211, 49.439998626709, 51.3899993896484, 39.0200004577637,
29.3700008392334, 34.0800018310547, 43.0699996948242, 49.2299995422363,
59.3499984741211, 59.0499992370605, 63.1599998474121, 60.4900016784668,
47.5800018310547, 53.2700004577637, 40.9300003051758, 13.0600004196167,
23.7399997711182, 34.7299995422363, 44.2099990844727, 51.5299987792969,
58.7799987792969, 61.3300018310547, 51.0099983215332, 54.189998626709,
50.7700004577637, 44.8600006103516, 29.6700000762939, 19.1100006103516,
25.5499992370605, 38.75, 50.75, 52.3600006103516, 53.3400001525879,
52.8600006103516, 54.7599983215332, 51.0800018310547, 46.8800010681152,
34.25, 17.1900005340576, 20.9599990844727, 23.5599994659424,
45.4599990844727, 45.2700004577637, 46.8699989318848, 46.560001373291,
44.5499992370605, 39.7700004577637, 29.5, 11.3299999237061, 21.4899997711182,
16.9099998474121, 42.0699996948242, 40.6800003051758, 44.9099998474121,
43.6800003051758, 42.939998626709, 26.5200004577637, 59.4700012207031,
47.9700012207031, 51.3300018310547, 54.9300003051758, 51.4199981689453,
29.4699993133545, 8.31999969482422, 69.9599990844727, 79.1100006103516,
59.5299987792969, 45.4300003051758, 24.3500003814697, 74.4800033569336,
75.2699966430664, 55.3499984741211, 41.9799995422363, 38.9599990844727,
54.1100006103516, 51.0800018310547, 40.6800003051758, 33.4199981689453,
53.4599990844727, 52.5999984741211, 48.9000015258789, 39.75,
53.3600006103516, 50.8400001525879, 45.3600006103516, 39.0800018310547,
43.9700012207031, 35.1199989318848, 35.6599998474121, 32.9000015258789,
30.3600006103516, 12.0900001525879, 22.2399997711182, 38.5999984741211,
44.8800010681152, 28.7600002288818, 19.3899993896484, 16.0400009155273,
24.0400009155273, 33.5900001525879, 29.6100006103516, 15.2299995422363,
19.0200004577637, 16.6700000762939, 12.3999996185303, 17.4099998474121,
12.7700004577637, 11.539999961853), pop = c(467.730682373047,
442.584808349609, 386.824768066406, 404.851593017578, 393.846649169922,
341.867309570312, 331.978668212891, 420.133941650391, 399.355560302734,
303.836303710938, 308.739990234375, 373.127471923828, 447.762542724609,
266.3720703125, 333.298156738281, 409.553863525391, 380.941284179688,
395.985443115234, 393.493804931641, 414.0380859375, 398.675659179688,
417.473693847656, 455.329650878906, 444.438934326172, 391.074279785156,
332.950469970703, 273.676300048828, 375.670349121094, 391.606079101562,
402.8857421875, 441.001770019531, 468.936431884766, 447.103881835938,
409.662139892578, 355.491485595703, 306.626922607422, 288.12451171875,
264.401550292969, 275.680816650391, 275.232696533203, 322.314422607422,
443.952575683594, 406.246276855469, 325.056121826172, 273.865203857422,
277.731384277344, 282.733306884766, 273.173400878906, 255.424011230469,
252.928161621094, 270.775115966797, 300.873901367188, 330.387023925781,
486.562194824219, 417.766937255859, 328.541778564453, 275.667938232422,
293.250091552734, 310.384307861328, 294.573760986328, 277.082794189453,
264.684020996094, 277.013519287109, 332.2822265625, 386.991516113281,
529.134033203125, 404.797668457031, 337.963989257812, 327.352325439453,
335.318450927734, 357.170532226562, 379.009216308594, 335.018920898438,
313.892242431641, 316.136169433594, 360.021911621094, 177.771179199219,
356.384826660156, 475.251892089844, 477.267425537109, 364.647186279297,
335.457702636719, 331.359771728516, 352.507507324219, 431.358001708984,
495.757507324219, 389.088348388672, 340.287353515625, 314.916931152344,
230.861862182617, 358.379241943359, 447.627899169922, 499.188049316406,
385.820648193359, 331.904357910156, 335.975555419922, 338.039215087891,
384.359283447266, 493.854125976562, 447.729370117188, 345.277893066406,
329.822601318359, 321.495574951172, 292.315155029297, 386.298004150391,
419.256134033203, 460.856658935547, 474.056365966797, 393.099731445312,
331.051147460938, 322.845062255859, 325.789916992188, 353.692932128906,
478.168914794922, 491.732391357422, 355.256286621094, 312.570648193359,
271.149200439453, 254.103805541992, 305.484161376953, 328.351654052734,
454.282196044922, 466.519104003906, 438.591888427734, 362.569458007812,
321.702514648438, 311.144805908203, 318.722686767578, 289.113739013672,
362.393524169922, 530.498779296875, 488.766021728516, 352.353332519531,
291.703857421875, 196.083953857422, 169.829040527344, 276.586456298828,
394.136077880859, 447.649536132812, 357.458465576172, 330.671844482422,
316.614013671875, 308.682250976562, 307.488372802734, 315.810272216797,
315.00830078125, 446.569549560547, 525.409790039062, 469.031463623047,
323.507110595703, 206.240097045898, 136.70866394043, 305.205505371094,
440.419525146484, 416.188415527344, 320.116851806641, 305.065795898438,
312.146514892578, 306.988250732422, 303.505004882812, 297.240936279297,
349.085205078125, 503.130523681641, 436.375579833984, 325.817047119141,
428.442840576172, 374.009552001953, 308.004852294922, 302.125427246094,
305.510589599609, 303.826202392578, 307.065765380859, 310.789855957031,
402.336883544922, 512.603881835938, 341.814819335938, 451.526214599609,
369.636932373047, 312.189910888672, 303.484680175781, 302.448150634766,
305.6005859375, 310.20556640625, 361.661041259766, 486.923309326172,
462.197082519531, 335.660766601562, 434.591369628906, 326.722381591797,
316.817443847656, 310.796447753906, 305.022094726562, 311.223114013672,
315.769500732422, 392.485260009766, 497.828857421875, 365.072479248047,
315.761871337891, 306.943206787109, 305.944671630859, 304.676391601562,
307.842864990234, 322.399810791016, 408.937316894531, 503.714477539062,
331.908355712891, 340.755310058594, 335.776062011719, 310.720001220703,
306.018585205078, 303.140838623047, 314.159790039062, 356.463195800781,
468.256072998047, 481.448883056641, 366.226470947266, 303.771667480469,
312.950714111328, 367.753631591797, 372.05712890625, 316.509735107422,
305.842926025391, 308.080749511719, 324.148742675781, 403.967834472656,
458.983428955078, 315.220581054688, 306.100189208984, 406.476257324219,
387.388549804688, 363.384002685547, 332.023681640625, 316.36376953125,
305.164978027344, 305.761657714844, 332.060546875, 430.189819335938,
437.18359375, 352.428100585938, 439.877349853516, 344.920806884766,
321.906311035156, 303.888641357422, 305.658386230469, 301.303771972656,
309.059204101562, 350.502899169922, 457.789031982422, 404.220825195312,
348.934783935547, 389.165191650391, 319.257232666016, 312.878845214844,
299.004425048828, 295.020935058594, 299.496704101562, 315.326171875,
392.207275390625, 466.010864257812, 360.850219726562, 207.557098388672,
311.1181640625, 364.520080566406, 421.305114746094, 384.428588867188,
322.168792724609, 298.988342285156, 303.18896484375, 350.786376953125,
458.250915527344, 443.151519775391, 324.209350585938, 272.503509521484,
285.845031738281, 320.443328857422, 406.536590576172, 454.334075927734,
401.587219238281, 322.675140380859, 308.545257568359, 409.520935058594,
483.659057617188, 375.045562744141, 255.126037597656, 311.212921142578,
216.756912231445, 368.455963134766, 458.809356689453, 401.597442626953,
325.810699462891, 427.660095214844, 482.688262939453, 331.141510009766,
181.774032592773, 213.352233886719, 187.602554321289, 388.55908203125,
449.115509033203, 357.346282958984, 417.838958740234, 463.803497314453,
350.715362548828, 316.349731445312, 438.630187988281, 402.3486328125,
461.848388671875, 451.344573974609, 387.394012451172, 82.5956802368164,
381.059661865234, 383.909515380859, 463.255798339844, 382.991790771484,
282.866302490234, 395.704193115234, 436.873657226562, 471.699554443359,
365.950927734375, 331.590698242188, 413.536224365234, 461.433929443359,
435.342834472656, 316.814453125, 340.904754638672, 456.206756591797,
458.000915527344, 360.565734863281, 335.498718261719, 432.590454101562,
413.760589599609, 327.622589111328, 356.019195556641, 308.775482177734,
223.692031860352, 330.171447753906, 328.636505126953, 141.314712524414,
174.631240844727, 267.42529296875, 291.692108154297, 282.262054443359,
238.984924316406, 171.093856811523, 278.318084716797, 288.979125976562,
207.751800537109, 171.422134399414, 269.854309082031, 243.924942016602,
90.4415740966797, 160.288192749023, 192.350738525391, 92.9382019042969
)), row.names = c(NA, 353L), class = "data.frame")
CodePudding user response:
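I'm not sure whether you know about saving the workspace; that may help with the problem.
# save.image() writes the whole workspace; save() would need the object names listed
save.image(file = "d:/filename.RData")
load(file = "d:/filename.RData")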
CodePudding user response:
If results_mtry$metrics returns a matrix with a column named "RMSE", then the call to extract that column would be results_mtry$metrics[, "RMSE"], so this might return the entire row where that column is at its minimum:
results_mtry$metrics[ which.min( results_mtry$metrics[, "RMSE"] ), ]
If you had provided (or later provide) enough of the output of str(results_mtry), and described which parts of it indicate the optimum and which parts you want based on that optimum, that would have gotten around the fact that "caret" is a huge download and there is now a note saying it's not available for this version of R. (I suspect that I could get around this quickly if I were at home, but I'm sitting in a hotel room in Thailand and the internet connection is weak.)
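To illustrate the row-selection idiom from this answer on its own, here is a minimal sketch with a small made-up matrix m. It is an assumption for demonstration only: according to the str() output in the question, results_mtry$metrics is a character vector rather than a matrix, so the idiom would have to be applied to a matrix of per-model summaries instead.

```r
# Made-up matrix: one row per model, one column per metric (placeholder values)
m <- matrix(c(7.2, 7.1, 7.3,
              9.6, 9.4, 9.5),
            ncol = 2,
            dimnames = list(c("5", "6", "7"), c("MAE", "RMSE")))

# Index of the row with the smallest RMSE
i <- which.min(m[, "RMSE"])

best_row  <- m[i, ]         # all metrics of the best model
best_name <- rownames(m)[i] # the model label, here the maxnodes value
print(best_name)
```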