How to make decision boundary plots from RandomForest in R

Time:12-22

I would like to make a decision boundary plot in R, similar to the one below, showing the decision boundaries of a random forest classifier. I would like to be able to tweak and customize it, ideally in something like ggplot.

[Image: example decision boundary plot]

This example comes from an RF classification of three classes using three predictors (PCs). I know the starting point is a scatterplot; the hard part is recreating the colored class regions. I'd appreciate any pointers.

Answer:

No additional packages are needed: the boundary can be drawn using only ggplot2 and randomForest.

Suppose we use the simple example of classifying the iris data set according to petal width and length:

library(randomForest)

rf <- randomForest(Species ~ Petal.Length + Petal.Width, data = iris, 
                   proximity = TRUE)
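Random forests resample both rows and predictors, so the fitted boundary can differ slightly from run to run. A minimal sketch that fixes the seed (42 is an arbitrary value) and prints the out-of-bag confusion matrix as a quick sanity check before plotting. The `proximity = TRUE` argument is not needed for prediction, so it is dropped here:

```r
library(randomForest)

# Random forests resample rows and predictors; fixing the seed makes the
# fitted model, and therefore the plotted boundary, reproducible
set.seed(42)
rf <- randomForest(Species ~ Petal.Length + Petal.Width, data = iris)

# Out-of-bag confusion matrix: one row per species plus a class.error column
print(rf$confusion)
```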

Next, create a regular grid of (x, y) values covering the range of the data with expand.grid, and get the predicted species at each grid point:

df <- expand.grid(Petal.Width = seq(0, 3, length.out = 100),
                  Petal.Length = seq(0, 7, length.out = 100))

df$Species <- predict(rf, df)
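The grid above uses hard-coded limits (0 to 3 and 0 to 7) that suit iris. With principal-component scores, as in the question, the ranges are not known in advance, so deriving the grid from the data may be more convenient. A minimal sketch; `make_grid` is a hypothetical helper written for this example, and the 5% padding is an arbitrary choice:

```r
library(randomForest)

set.seed(1)
rf <- randomForest(Species ~ Petal.Length + Petal.Width, data = iris)

# Hypothetical helper: an n x n grid over the observed range of columns x and y,
# extended by a fraction `pad` of each range on both sides
make_grid <- function(data, x, y, n = 100, pad = 0.05) {
  axis_seq <- function(v) {
    r <- range(v, na.rm = TRUE)
    d <- diff(r) * pad
    seq(r[1] - d, r[2] + d, length.out = n)
  }
  g <- expand.grid(axis_seq(data[[x]]), axis_seq(data[[y]]))
  names(g) <- c(x, y)  # expand.grid names unnamed inputs Var1, Var2
  g
}

df <- make_grid(iris, "Petal.Width", "Petal.Length")
df$Species <- predict(rf, df)  # predicted class at every grid point
head(df)
```

The resulting `df` has the same columns as the one built with expand.grid above, so the ggplot call that follows works unchanged.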

Now plot the boundary with geom_raster, overlaying the original observations as points:

library(ggplot2)

ggplot(iris, aes(Petal.Width, Petal.Length, fill = Species)) +
  geom_raster(data = df, alpha = 0.5) +
  geom_point(shape = 21, size = 3) +
  theme_minimal()

Created on 2022-12-21 with reprex v2.0.2
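The example uses two predictors, while the question has three (PCs). A 2-D map can only show a slice through a 3-D decision region, so one common workaround is to hold the third predictor at a fixed value (its median here) while varying the two plotted ones. A sketch under that assumption, with Sepal.Length standing in for the third PC:

```r
library(randomForest)
library(ggplot2)

set.seed(1)
# Assumption: Sepal.Length plays the role of the third predictor (PC3)
rf3 <- randomForest(Species ~ Petal.Length + Petal.Width + Sepal.Length,
                    data = iris)

# Grid over the two plotted predictors; the third is held at its median,
# so the map shows a single slice through the 3-D decision regions
df3 <- expand.grid(Petal.Width = seq(0, 3, length.out = 100),
                   Petal.Length = seq(0, 7, length.out = 100))
df3$Sepal.Length <- median(iris$Sepal.Length)
df3$Species <- predict(rf3, df3)

p3 <- ggplot(iris, aes(Petal.Width, Petal.Length, fill = Species)) +
  geom_raster(data = df3, alpha = 0.5) +
  geom_point(shape = 21, size = 3) +
  labs(caption = sprintf("Sepal.Length fixed at %.1f",
                         median(iris$Sepal.Length))) +
  theme_minimal()
print(p3)
```

Repeating the prediction for a few values of the third predictor (for example its quartiles), stacking the grids with a column recording that value, and adding facet_wrap would show several slices side by side.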
