Get waterfall plot values of a feature in a dataframe using shap package

Time:04-07

I am working on a binary classification problem using random forest and neural network models, and I am using SHAP to explain the model predictions. I followed the tutorial and wrote code to produce a waterfall plot. Now I would like to get the values shown in that waterfall plot for each feature into a dataframe.

This was done with the help of Sergey Bushmanov's SO post.

Answer:

Try the following:

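import shap
from shap import TreeExplainer, Explanation
from shap.plots import waterfall
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

print(shap.__version__)

# Fit a random forest on the breast cancer data (as a DataFrame, so feature names are kept)
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = RandomForestClassifier(max_depth=5, n_estimators=100).fit(X, y)

# For a two-class forest, sv.values has shape (n_samples, n_features, 2)
# and sv.base_values has shape (n_samples, 2): one set per class
explainer = TreeExplainer(model)
sv = explainer(X)

# Keep only the explanation for class 1 so waterfall() gets one value per feature
exp = Explanation(sv.values[:, :, 1],
                  sv.base_values[:, 1],
                  data=X.values,
                  feature_names=X.columns)

idx = 0  # row to explain
waterfall(exp[idx])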

Output of print(shap.__version__): 0.39.0


Then, put the values for that row into a DataFrame. exp[idx].data holds the feature values and exp[idx].values holds the SHAP values:

import pandas as pd

df = pd.DataFrame({
    'row_id': idx,
    'feature': X.columns,
    'feature_value': exp[idx].data,      # raw feature values of this row
    'base_value': exp[idx].base_values,  # expected model output E[f(X)]
    'shap_values': exp[idx].values,      # per-feature SHAP contributions
})
print(df)

# example output (SHAP values vary between runs because the forest is not seeded)
    row_id  feature                  feature_value  base_value  shap_values
0   0       mean radius                  17.990000    0.628998    -0.035453
1   0       mean texture                 10.380000    0.628998     0.047571
2   0       mean perimeter              122.800000    0.628998    -0.036218
3   0       mean area                  1001.000000    0.628998    -0.041276
4   0       mean smoothness               0.118400    0.628998    -0.006842
5   0       mean compactness              0.277600    0.628998    -0.009275
6   0       mean concavity                0.300100    0.628998    -0.035188
7   0       mean concave points           0.147100    0.628998    -0.051165
8   0       mean symmetry                 0.241900    0.628998    -0.002192
9   0       mean fractal dimension        0.078710    0.628998     0.001521
10  0       radius error                  1.095000    0.628998    -0.021223
11  0       texture error                 0.905300    0.628998    -0.000470
12  0       perimeter error               8.589000    0.628998    -0.021423
13  0       area error                  153.400000    0.628998    -0.035313
14  0       smoothness error              0.006399    0.628998    -0.000060
15  0       compactness error             0.049040    0.628998     0.001053
16  0       concavity error               0.053730    0.628998    -0.002988
17  0       concave points error          0.015870    0.628998     0.000140
18  0       symmetry error                0.030030    0.628998     0.001238
19  0       fractal dimension error       0.006193    0.628998    -0.001097
20  0       worst radius                 25.380000    0.628998    -0.050027
21  0       worst texture                17.330000    0.628998     0.038056
22  0       worst perimeter             184.600000    0.628998    -0.079717
23  0       worst area                 2019.000000    0.628998    -0.072312
24  0       worst smoothness              0.162200    0.628998    -0.006917
25  0       worst compactness             0.665600    0.628998    -0.016184
26  0       worst concavity               0.711900    0.628998    -0.022500
27  0       worst concave points          0.265400    0.628998    -0.088697
28  0       worst symmetry                0.460100    0.628998    -0.026166
29  0       worst fractal dimension       0.118900    0.628998    -0.007683
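The single-row table lists features in dataset order, while the waterfall plot stacks them by the size of their contribution. As far as I understand the shap 0.39 defaults (an assumption, not checked against the plotting source), features are sorted by absolute SHAP value, only `max_display=10` bars are drawn, and the remaining features are summed into one "N other features" bar. The plot also runs from E[f(X)] (the base value) up to f(x) = base_value + sum of the SHAP values. For the sample output above the SHAP values sum to -0.580807, so f(x) = 0.628998 - 0.580807 = 0.048191, which should match model.predict_proba(X)[0, 1] for that fitted forest. The sketch below reorders the table the same way; `waterfall_table` and `model_output` are hypothetical helpers, not part of shap.

```python
import pandas as pd


def waterfall_table(df: pd.DataFrame, max_display: int = 10) -> pd.DataFrame:
    """Reorder one row's SHAP table roughly as a waterfall plot shows it (top to bottom).

    Hypothetical helper. Assumes `df` has 'feature', 'feature_value',
    'base_value' and 'shap_values' columns for a single row.
    """
    # Largest absolute contribution first, like the top bar of the plot
    order = df["shap_values"].abs().sort_values(ascending=False, kind="stable").index
    ranked = df.loc[order].reset_index(drop=True)

    if len(ranked) <= max_display:
        return ranked

    # Keep max_display - 1 individual features and fold the rest into one row
    top = ranked.iloc[: max_display - 1]
    rest = ranked.iloc[max_display - 1 :]
    other = rest.iloc[:1].copy()
    other["feature"] = f"{len(rest)} other features"
    other["feature_value"] = float("nan")
    other["shap_values"] = rest["shap_values"].sum()
    return pd.concat([top, other], ignore_index=True)


def model_output(df: pd.DataFrame) -> float:
    """f(x) shown at the top of the plot: base value plus the sum of all SHAP values."""
    return float(df["base_value"].iloc[0] + df["shap_values"].sum())
```

With the answer's DataFrame assigned to `df`, `waterfall_table(df)` lists the bars from top to bottom and `model_output(df)` gives the f(x) value shown at the top of the plot.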

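RandomForest is a bit special: for a binary classifier, TreeExplainer returns SHAP values and base values for both classes, which is why the class-1 slice is wrapped in a new Explanation object before calling waterfall. In general, when something fails with the new plots API, try feeding it an Explanation object.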
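Because `sv` keeps that class axis for every row, the same slicing can be applied to all rows at once. The sketch below builds a long-format DataFrame with one row per (sample, feature) pair, which makes it easy to pull out a single feature across all rows, as the question's title asks. It assumes the shapes implied by the answer's indexing: `sv.values` of shape (n_samples, n_features, 2) and `sv.base_values` of shape (n_samples, 2). `shap_long_frame` and its `class_index` parameter are hypothetical names.

```python
import numpy as np
import pandas as pd


def shap_long_frame(values, base_values, data, feature_names, class_index=1):
    """One row per (sample, feature) pair, like the single-row table but for all rows.

    Hypothetical helper. Accepts 2-D SHAP values (n_samples, n_features) or
    3-D per-class values (n_samples, n_features, n_classes).
    """
    values = np.asarray(values)
    base_values = np.asarray(base_values)
    if values.ndim == 3:
        # Keep one class, mirroring sv.values[:, :, 1] and sv.base_values[:, 1]
        values = values[:, :, class_index]
        base_values = base_values[:, class_index]
    n_samples, n_features = values.shape
    # One base value per sample (also covers a single scalar base value)
    base_values = np.broadcast_to(base_values, (n_samples,))
    data = np.asarray(data).reshape(n_samples, n_features)
    return pd.DataFrame({
        "row_id": np.repeat(np.arange(n_samples), n_features),
        "feature": np.tile(np.asarray(feature_names), n_samples),
        "feature_value": data.ravel(),
        "base_value": np.repeat(base_values, n_features),
        "shap_values": values.ravel(),
    })
```

For example, `long_df = shap_long_frame(sv.values, sv.base_values, X.values, X.columns)`; then `long_df[long_df['row_id'] == 0]` should match the single-row table, and `long_df[long_df['feature'] == 'mean radius']` lists that feature's value and SHAP value for every row.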
