I am working on a binary classification problem using a random forest model and neural networks, and I am using SHAP to explain the model predictions. I followed the tutorial and, with the help of Sergey Bushmanov's SO post, wrote the code below to get the waterfall plot.
Answer:
Try the following:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from shap import TreeExplainer, Explanation
from shap.plots import waterfall
import shap
print(shap.__version__)
# Load the data as a DataFrame so the feature names are kept
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = RandomForestClassifier(max_depth=5, n_estimators=100).fit(X, y)
explainer = TreeExplainer(model)
# sv.values has shape (n_samples, n_features, 2): one slice per class
sv = explainer(X)
# Keep only the SHAP values and base values for class 1
exp = Explanation(sv.values[:,:,1],
                  sv.base_values[:,1],
                  data=X.values,
                  feature_names=X.columns)
# Waterfall plot for the first row
idx = 0
waterfall(exp[idx])
Output of print(shap.__version__): 0.39.0
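The waterfall plot starts at the base value E[f(X)] and adds one bar per feature until it reaches the model output f(x) for that row; this is SHAP's local accuracy property. As a quick arithmetic sketch, the row-0 numbers from the expected output further down should add up to the predicted probability of class 1. With the live objects you could instead compare exp[idx].base_values + exp[idx].values.sum() against model.predict_proba(X.iloc[[idx]])[0, 1]. The forest is trained without a fixed random_state, so your own numbers will differ slightly, and the copied values are rounded to six decimals.

```python
# Row-0 base value and class-1 SHAP values, copied from the expected
# output table (rounded to 6 decimals).
base_value = 0.628998
shap_values = [
    -0.035453, 0.047571, -0.036218, -0.041276, -0.006842,
    -0.009275, -0.035188, -0.051165, -0.002192, 0.001521,
    -0.021223, -0.000470, -0.021423, -0.035313, -0.000060,
    0.001053, -0.002988, 0.000140, 0.001238, -0.001097,
    -0.050027, 0.038056, -0.079717, -0.072312, -0.006917,
    -0.016184, -0.022500, -0.088697, -0.026166, -0.007683,
]

# Local accuracy: base value + sum of SHAP values = model output for the row,
# i.e. the value at the top of the waterfall plot.
model_output = base_value + sum(shap_values)
print(f"{model_output:.6f}")  # 0.048191
```

A value this close to 0 means the forest assigns row 0 a low probability of class 1, which matches the mostly negative bars in the table.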
Then, to put the values behind the plot into a DataFrame:
import pandas as pd

# One row per feature for observation idx:
# .data holds the raw feature values, .values holds the SHAP values
pd.DataFrame({
    'row_id': idx,
    'feature': X.columns,
    'feature_value': exp[idx].data,
    'base_value': exp[idx].base_values,
    'shap_values': exp[idx].values
})
# expected output
    row_id                  feature  feature_value  base_value  shap_values
 0       0              mean radius      17.990000    0.628998    -0.035453
 1       0             mean texture      10.380000    0.628998     0.047571
 2       0           mean perimeter     122.800000    0.628998    -0.036218
 3       0                mean area    1001.000000    0.628998    -0.041276
 4       0          mean smoothness       0.118400    0.628998    -0.006842
 5       0         mean compactness       0.277600    0.628998    -0.009275
 6       0           mean concavity       0.300100    0.628998    -0.035188
 7       0      mean concave points       0.147100    0.628998    -0.051165
 8       0            mean symmetry       0.241900    0.628998    -0.002192
 9       0   mean fractal dimension       0.078710    0.628998     0.001521
10       0             radius error       1.095000    0.628998    -0.021223
11       0            texture error       0.905300    0.628998    -0.000470
12       0          perimeter error       8.589000    0.628998    -0.021423
13       0               area error     153.400000    0.628998    -0.035313
14       0         smoothness error       0.006399    0.628998    -0.000060
15       0        compactness error       0.049040    0.628998     0.001053
16       0          concavity error       0.053730    0.628998    -0.002988
17       0     concave points error       0.015870    0.628998     0.000140
18       0           symmetry error       0.030030    0.628998     0.001238
19       0  fractal dimension error       0.006193    0.628998    -0.001097
20       0             worst radius      25.380000    0.628998    -0.050027
21       0            worst texture      17.330000    0.628998     0.038056
22       0          worst perimeter     184.600000    0.628998    -0.079717
23       0               worst area    2019.000000    0.628998    -0.072312
24       0         worst smoothness       0.162200    0.628998    -0.006917
25       0        worst compactness       0.665600    0.628998    -0.016184
26       0          worst concavity       0.711900    0.628998    -0.022500
27       0     worst concave points       0.265400    0.628998    -0.088697
28       0           worst symmetry       0.460100    0.628998    -0.026166
29       0  worst fractal dimension       0.118900    0.628998    -0.007683
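The waterfall plot lists features by the absolute size of their SHAP values and, by default (max_display=10), collapses the smallest ones into a single "other features" bar. A sketch of a table in roughly that order, built from the row-0 SHAP values above; the split into max_display - 1 individual rows plus one "other features" row is an assumption about how the plot counts its bars. On the live DataFrame from the snippet above, the same sort_values call works unchanged (the key argument needs pandas 1.1 or later).

```python
import pandas as pd

# Row-0 feature names and class-1 SHAP values, copied from the table above.
features = [
    "mean radius", "mean texture", "mean perimeter", "mean area",
    "mean smoothness", "mean compactness", "mean concavity",
    "mean concave points", "mean symmetry", "mean fractal dimension",
    "radius error", "texture error", "perimeter error", "area error",
    "smoothness error", "compactness error", "concavity error",
    "concave points error", "symmetry error", "fractal dimension error",
    "worst radius", "worst texture", "worst perimeter", "worst area",
    "worst smoothness", "worst compactness", "worst concavity",
    "worst concave points", "worst symmetry", "worst fractal dimension",
]
shap_values = [
    -0.035453, 0.047571, -0.036218, -0.041276, -0.006842,
    -0.009275, -0.035188, -0.051165, -0.002192, 0.001521,
    -0.021223, -0.000470, -0.021423, -0.035313, -0.000060,
    0.001053, -0.002988, 0.000140, 0.001238, -0.001097,
    -0.050027, 0.038056, -0.079717, -0.072312, -0.006917,
    -0.016184, -0.022500, -0.088697, -0.026166, -0.007683,
]
df = pd.DataFrame({"feature": features, "shap_values": shap_values})

# Largest absolute contribution first, as in the waterfall plot.
ranked = df.sort_values("shap_values", key=lambda s: s.abs(), ascending=False)

# Assumption: keep max_display - 1 features individually and sum the rest
# into one "other features" row.
max_display = 10
top = ranked.head(max_display - 1)
rest = ranked.iloc[max_display - 1:]
other = pd.DataFrame({
    "feature": [f"{len(rest)} other features"],
    "shap_values": [rest["shap_values"].sum()],
})
summary = pd.concat([top, other], ignore_index=True)
print(summary)
```

Summing the remainder instead of dropping it keeps the total unchanged, so base_value plus the summary column still reaches the same model output.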
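RandomForest is a bit special: for a classifier, TreeExplainer returns SHAP values and base values for every class (here sv.values has shape (n_samples, n_features, 2)), while the waterfall plot needs a single output per row. That is why the code builds a new Explanation from the class-1 slices sv.values[:,:,1] and sv.base_values[:,1]. When something fails with the new plots API, try feeding it an Explanation object.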
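Depending on the SHAP version and on whether you call explainer.shap_values(X) or explainer(X), per-class SHAP values for a classifier can come back either as a list with one array per class or as a single array with the class as the last axis (the layout sv.values has in the code above). Below is a small hypothetical helper, class_shap_values, that picks one class from either layout; the toy arrays are random stand-ins with the breast-cancer shape, not real SHAP output. For a binary random forest, keeping only class 1 loses nothing: the two class probabilities sum to 1, so the class-0 SHAP values are the negatives of the class-1 values and the two base values sum to 1.

```python
import numpy as np


def class_shap_values(values, cls=1):
    """Return an (n_samples, n_features) array of SHAP values for one class.

    Hypothetical helper. Accepts either a list holding one
    (n_samples, n_features) array per class, or a single
    (n_samples, n_features, n_classes) array.
    """
    if isinstance(values, list):
        return np.asarray(values[cls])
    arr = np.asarray(values)
    if arr.ndim == 3:
        return arr[:, :, cls]
    # Already single-output: nothing to select.
    return arr


# Toy stand-ins with the breast-cancer shape (569 rows, 30 features).
# Class-0 values are built as the negatives of class-1 values, mirroring
# what a binary random forest produces.
rng = np.random.default_rng(0)
class1 = rng.normal(size=(569, 30))
as_list = [-class1, class1]
as_array = np.stack(as_list, axis=-1)  # shape (569, 30, 2)

print(class_shap_values(as_list).shape)   # (569, 30)
print(class_shap_values(as_array).shape)  # (569, 30)
```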