find the path/rules for each node partykit

Time:10-24

Is it possible to find the path/rules for each node? I want to extract the rules for each node, and not just for the terminal nodes.

Example:

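library(partykit)
library(rpart)                  # rpart() and rpart.control() live here

set.seed(1)                     # reproducible simulated data
n    <- 500                     # number of observations (illustrative value)
p    <- 5                       # number of predictors (illustrative value)
X    <- MASS::mvrnorm(n, rep(0, p), diag(p))
y    <- as.numeric(drop(X %*% rep(1, p)) > 2)

data <- data.frame(y, X)        # predictor columns are named X1, ..., Xp

tree      <- rpart(y ~ .,
                   data = data,
                   control = rpart.control(cp = 0.005))
pfit      <- as.party(tree)     # convert the rpart fit to a party object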

I can use partykit:::.list.rules.party(pfit), but this returns only the rules for the terminal nodes. I'm looking for the rules for every node.

Answer:

Set the argument i to the IDs of all nodes whose rules you want. nodeids() returns all node IDs by default (nodeids(pfit, terminal = TRUE) would return only the terminal ones):

R> partykit:::.list.rules.party(pfit, i = nodeids(pfit))
                                                                          1
                                                                         "" 
                                                                          2 
                                                   "X3 < 0.650618460125409" 
                                                                          3 
                           "X3 < 0.650618460125409 & X2 < 1.62837615944647" 
                                                                          4 
   "X3 < 0.650618460125409 & X2 < 1.62837615944647 & X4 < 1.38485264813313" 
...
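If you want the rules next to the node IDs and a flag marking terminal vs. inner nodes, one possible sketch is to collect everything in a data frame. It assumes the simulated data from the question, with n = 500, p = 5 and the seed chosen here for illustration (they are not from the original post); the helper name rules_table is hypothetical. Note that .list.rules.party() is an internal, unexported partykit function (hence :::), so its behaviour may change between versions.

```r
library(partykit)
library(rpart)

# Simulated data as in the question; n, p and the seed are illustrative choices
set.seed(1)
n <- 500
p <- 5
X <- MASS::mvrnorm(n, rep(0, p), diag(p))
y <- as.numeric(drop(X %*% rep(1, p)) > 2)
data <- data.frame(y, X)

tree <- rpart(y ~ ., data = data, control = rpart.control(cp = 0.005))
pfit <- as.party(tree)

# Hypothetical helper: one row per node with its ID, a terminal flag and its rule
rules_table <- function(obj) {
  ids   <- nodeids(obj)                               # all node IDs (terminal = FALSE by default)
  rules <- partykit:::.list.rules.party(obj, i = ids) # internal partykit function
  data.frame(
    id       = ids,
    terminal = ids %in% nodeids(obj, terminal = TRUE),
    rule     = unname(rules),
    stringsAsFactors = FALSE
  )
}

tab <- rules_table(pfit)
head(tab)

# Only the inner (non-terminal) nodes
tab[!tab$terminal, ]
```

The root node (ID 1) gets an empty rule, since no split has been applied yet; every other node's rule is the conjunction of the splits on the path from the root.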