scikit-learn RandomForestClassifier list all variables of an estimator tree?

Time:01-20

I train a RandomForestClassifier as

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
X, y = make_classification()
clf = RandomForestClassifier()
clf.fit(X, y)

where X and y are the feature vectors and labels (here generated with make_classification).

Once fitting is done, I can, e.g., list the depth of every tree in the forest as follows:

[estimator.tree_.max_depth for estimator in clf.estimators_]
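The same list-comprehension pattern can collect several per-tree statistics at once. The sketch below reads node_count and n_leaves, two other public attributes of tree_, next to max_depth. The n_estimators=10 and random_state=0 arguments are added here only to keep the run small and reproducible; they are not part of the original setup.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small, reproducible forest (n_estimators and random_state are illustrative choices).
X, y = make_classification(random_state=0)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# One tuple per tree: depth, total number of nodes, and number of leaves.
summary = [
    (est.tree_.max_depth, est.tree_.node_count, est.tree_.n_leaves)
    for est in clf.estimators_
]
for depth, nodes, leaves in summary:
    print(f"depth={depth} nodes={nodes} leaves={leaves}")
```

Because every split in these trees produces exactly two children, each tree should satisfy nodes == 2 * leaves - 1.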

Now I would like to find out which other public attributes (apart from max_depth) the tree_ of an estimator stores. So I tried:

vars(clf.estimators_[0].tree_)

but unfortunately this does not work and raises the error

TypeError: vars() argument must have __dict__ attribute

What syntax can I use to list all public attributes of an estimator's tree_?

Answer:

vars() fails because Tree is a Cython extension type that has no __dict__. The Tree class does, however, expose the following public attributes (most of them are described in the Tree class docstring):

  • capacity
  • children_left
  • children_right
  • feature
  • impurity
  • max_depth
  • max_n_classes
  • n_classes
  • n_features
  • n_leaves
  • n_node_samples
  • n_outputs
  • node_count
  • threshold
  • value
  • weighted_n_node_samples
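If you would rather discover these names programmatically, dir() still works on extension types even though vars() does not. A minimal sketch, assuming the forest from the question (n_estimators=10 and random_state=0 are added only for reproducibility): filter out underscore-prefixed names and callables such as the apply and predict methods. The exact set of names can differ between scikit-learn versions, since newer releases add attributes.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(random_state=0)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
tree = clf.estimators_[0].tree_

# dir() lists the descriptors defined on the extension type, even without __dict__.
# Keep only names that are not private/dunder and whose values are not methods.
public_attrs = {
    name: getattr(tree, name)
    for name in dir(tree)
    if not name.startswith("_") and not callable(getattr(tree, name))
}
for name in sorted(public_attrs):
    print(name, type(public_attrs[name]).__name__)
```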

For more details, see: https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
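Several of the listed attributes are parallel arrays indexed by node id: node i splits on feature[i] at threshold[i], and its children are children_left[i] and children_right[i], with -1 marking a missing child. The sketch below walks one tree iteratively in the spirit of the linked example, assuming the forest from the question (n_estimators=10 and random_state=0 are illustrative additions), and cross-checks the traversal against n_leaves and max_depth.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(random_state=0)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
tree = clf.estimators_[0].tree_

# A leaf has no children: both child arrays hold -1 for that node.
is_leaf = tree.children_left == tree.children_right

# Depth-first walk from the root (node 0, depth 0) using an explicit stack.
node_depth = np.zeros(tree.node_count, dtype=int)
stack = [(0, 0)]
while stack:
    node, depth = stack.pop()
    node_depth[node] = depth
    if is_leaf[node]:
        print(f"{'  ' * depth}leaf {node}: samples={tree.n_node_samples[node]}")
    else:
        print(f"{'  ' * depth}node {node}: X[:, {tree.feature[node]}] <= {tree.threshold[node]:.3f}")
        stack.append((tree.children_right[node], depth + 1))
        stack.append((tree.children_left[node], depth + 1))
```

The right child is pushed first so that the left subtree is printed before the right one.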
