Is it possible to get the path (the splitting rules) leading to each node of a tree? I want to extract the rules for every node, not just for the terminal nodes.
Example:
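library(rpart)     # rpart() and rpart.control()
library(partykit)  # as.party(), nodeids()
set.seed(1)        # illustrative; the original snippet does not set a seed
n <- 1000          # illustrative; n and p are not defined in the original snippet
p <- 5
X <- MASS::mvrnorm(n, rep(0, p), diag(p))
y <- as.numeric(drop(X %*% rep(1, p)) > 2)
data <- data.frame(y, X)
tree <- rpart(y ~ .,
              data = data,
              control = rpart.control(cp = 0.005))
pfit <- as.party(tree)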
I can use partykit:::.list.rules.party(pfit),
but this returns only the rules for the terminal nodes. I'm looking for the rules for every node.
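A minimal sketch for checking which nodes the default call covers, assuming n = 1000, p = 5 and set.seed(1) (the question does not define these values). The names of the default result should coincide with nodeids(pfit, terminal = TRUE), and the inner nodes are the ones it leaves out.

```r
library(rpart)
library(partykit)

# Simulated data as in the question; n, p and the seed are illustrative
# assumptions, since the question does not define them.
set.seed(1)
n <- 1000
p <- 5
X <- MASS::mvrnorm(n, rep(0, p), diag(p))
y <- as.numeric(drop(X %*% rep(1, p)) > 2)
data <- data.frame(y, X)
pfit <- as.party(rpart(y ~ ., data = data, control = rpart.control(cp = 0.005)))

# Default call (no i): one rule per terminal node only
rules_default <- partykit:::.list.rules.party(pfit)

term_ids <- nodeids(pfit, terminal = TRUE)  # IDs of the terminal nodes
all_ids  <- nodeids(pfit)                   # IDs of all nodes, inner ones included

print(names(rules_default))         # node IDs covered by the default call
print(term_ids)
print(setdiff(all_ids, term_ids))   # inner nodes skipped by the default call
```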
Answer:
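.list.rules.party() is an unexported helper (hence the :::), so its interface may change between partykit versions, but it accepts an argument i
giving the IDs of the nodes whose rules you want. nodeids() returns the IDs of all nodes by default
(terminal = TRUE restricts it to the terminal nodes), so you can pass its result directly: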
R> partykit:::.list.rules.party(pfit, i = nodeids(pfit))
1
""
2
"X3 < 0.650618460125409"
3
"X3 < 0.650618460125409 & X2 < 1.62837615944647"
4
"X3 < 0.650618460125409 & X2 < 1.62837615944647 & X4 < 1.38485264813313"
...
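The result is a named character vector whose names are the node IDs; the root (node 1) gets an empty rule because no split precedes it. A minimal sketch for collecting the rules in a data frame and evaluating each one against the data, again assuming n = 1000, p = 5 and set.seed(1); the helper rows_in_node() is hypothetical, not part of partykit. Evaluating the rule strings with parse()/eval() works here because all predictors are numeric and each rule is plain R syntax such as "X3 < 0.65 & X2 < 1.63".

```r
library(rpart)
library(partykit)

# Illustrative assumptions: seed, n and p are not given in the question
set.seed(1)
n <- 1000
p <- 5
X <- MASS::mvrnorm(n, rep(0, p), diag(p))
y <- as.numeric(drop(X %*% rep(1, p)) > 2)
data <- data.frame(y, X)
pfit <- as.party(rpart(y ~ ., data = data, control = rpart.control(cp = 0.005)))

# Rules for every node (inner and terminal), as in the answer
ids   <- nodeids(pfit)
rules <- partykit:::.list.rules.party(pfit, i = ids)

# One row per node; the root (node 1) has an empty rule
rules_df <- data.frame(
  node     = ids,
  terminal = ids %in% nodeids(pfit, terminal = TRUE),
  rule     = unname(rules),
  stringsAsFactors = FALSE
)
print(rules_df)

# Hypothetical helper: logical vector of the rows that satisfy a rule string
rows_in_node <- function(rule, data) {
  if (rule == "") return(rep(TRUE, nrow(data)))  # root: every observation
  eval(parse(text = rule), envir = data)
}

# Number of observations satisfying each node's rule
n_in_node <- vapply(rules_df$rule,
                    function(r) sum(rows_in_node(r, data)),
                    integer(1), USE.NAMES = FALSE)
print(cbind(rules_df[c("node", "terminal")], n = n_in_node))
```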