I train a RandomForestClassifier as follows:
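from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
X, y = make_classification()  # synthetic feature vectors and labels
clf = RandomForestClassifier()
clf.fit(X, y)  # populates clf.estimators_, one fitted decision tree per estimator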
where X and y are the feature vectors and labels.
Once the fit is done, I can, e.g., list the maximum depth of every tree in the forest as follows:
[estimator.tree_.max_depth for estimator in clf.estimators_]
Now I would like to find out which other public attributes (apart from max_depth) the tree_ object of an estimator stores. So I tried:
vars(clf.estimators_[0].tree_)
but unfortunately this fails with the error
TypeError: vars() argument must have __dict__ attribute
What syntax can I use to list all public attributes of an estimator's tree_?
Answer:
vars() cannot list these attributes because Tree is a Cython extension type that has no __dict__. The documentation of the Tree class describes its attributes:
- capacity
- children_left
- children_right
- feature
- impurity
- max_depth
- max_n_classes
- n_classes
- n_features
- n_leaves
- n_node_samples
- n_outputs
- node_count
- threshold
- value
- weighted_n_node_samples
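If you want the list from the object itself rather than from the docs, dir() still works on extension types because it does not rely on __dict__. The sketch below filters dir() down to public, non-callable attributes; public_attributes is a hypothetical helper written for this answer, and the exact set of names it finds may vary between scikit-learn versions.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(random_state=0)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
tree = clf.estimators_[0].tree_

# vars() requires a __dict__, which this Cython extension type lacks
print(hasattr(tree, "__dict__"))

def public_attributes(obj):
    """Hypothetical helper: {name: value} for each public, non-callable attribute from dir()."""
    attrs = {}
    for name in dir(obj):
        if name.startswith("_"):
            continue  # skip private and dunder names
        value = getattr(obj, name)
        if not callable(value):  # drop methods such as apply() or predict()
            attrs[name] = value
    return attrs

attrs = public_attributes(tree)
print(sorted(attrs))
```

Filtering with callable() separates data attributes (node_count, threshold, ...) from methods that dir() also reports.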
For an example that uses these attributes to inspect the tree structure, see: https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
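As a rough sketch in the spirit of that example, the arrays above are enough to walk one tree by hand: node 0 is the root, and a node is a leaf when children_left and children_right are equal (both -1). node_depths is a hypothetical helper for illustration; the recovered depth and leaf count should match the max_depth and n_leaves attributes.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(random_state=0)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
tree = clf.estimators_[0].tree_

def node_depths(tree):
    """Hypothetical helper: depth of every node and a leaf mask, via an explicit stack."""
    depths = np.zeros(tree.node_count, dtype=int)
    is_leaf = np.zeros(tree.node_count, dtype=bool)
    stack = [(0, 0)]  # (node id, depth), starting at the root
    while stack:
        node, depth = stack.pop()
        depths[node] = depth
        left, right = tree.children_left[node], tree.children_right[node]
        if left == right:  # leaf: both children are -1
            is_leaf[node] = True
        else:
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
    return depths, is_leaf

depths, is_leaf = node_depths(tree)
# Split nodes test X[:, feature[node]] <= threshold[node]
for node in np.flatnonzero(~is_leaf)[:3]:
    print(f"node {node}: X[:, {tree.feature[node]}] <= {tree.threshold[node]:.3f}")
print(depths.max(), tree.max_depth, is_leaf.sum(), tree.n_leaves)
```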