Can't display bar plot with SHAP

Time:11-26

I'm new to SHAP and trying to use it on top of my RandomForestClassifier. Here's the code snippet after I already ran clf.fit(train_x, train_y):

explainer = shap.Explainer(clf)
shap_values = explainer(train_x.to_numpy()[0:5, :])
shap.summary_plot(shap_values, plot_type='bar')

Here's the resulting plot: [plot image not reproduced]

Now, there are two problems with this. One is that it is not a bar plot, even though I set the plot_type parameter. The other is that I seem to have lost my feature names somehow (and yes, they do exist on the DataFrames when calling clf.fit()).

I tried replacing the last line with:

shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], plot_type='bar')

And that changed nothing. I also tried to replace it with the following to see if I could at least recover my feature names:

shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')

But that threw an error:

Traceback (most recent call last):
  File "sklearn_model_runs.py", line 41, in <module>
    main()
  File "sklearn_model_runs.py", line 38, in main
    shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')
  File "C:\Users\kapoo\anaconda3\envs\sci\lib\site-packages\shap\plots\_beeswarm.py", line 554, in summary_legacy
    feature_names=feature_names[sort_inds],
TypeError: only integer scalar arrays can be converted to a scalar index

I'm kind of at a loss at this point. I just tried it with 5 rows of the training set but want to use the whole thing once I get past this stumbling block. If it helps, the classifier had 5 labels and my SHAP version is 0.40.0.

Answer:

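Alright, here was the problem. Because the classifier has 5 classes, explainer(...) returns an Explanation whose values have shape (samples, features, classes). The legacy summary_plot in SHAP 0.40 treats any 3-D array as SHAP interaction values, so it draws an interaction grid and ignores plot_type. That code path also indexes feature_names with a NumPy array, which is what raises the TypeError when you pass a plain list. explainer.shap_values() instead returns a list with one array per class, which summary_plot draws as a multi-class bar plot. Replace this: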

shap_values = explainer(train_x.to_numpy()[0:5, :])

With this:

shap_values = explainer.shap_values(train_x) # Use whole thing as dataframe

Then pass the feature names when plotting:

shap.summary_plot(shap_values, train_x, feature_names=list(train_x.columns.values), plot_type='bar')
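Depending on the SHAP version, explainer.shap_values() may hand back either a list of per-class arrays or a single 3-D array, and explainer(...) wraps the values in an Explanation. A small normalizing step keeps summary_plot on the list-based multi-class path. This is a minimal sketch, not part of SHAP: to_per_class_list is a hypothetical helper, and it assumes the values come as a list of (samples, features) arrays, one (samples, features, classes) array, or an object exposing such an array via .values.

```python
import numpy as np


def to_per_class_list(shap_values):
    """Hypothetical helper: return one (n_samples, n_features) array per class.

    Accepts a list of per-class arrays, a single 3-D array shaped
    (n_samples, n_features, n_classes), or an object exposing such an
    array through a .values attribute (for example a SHAP Explanation).
    """
    if isinstance(shap_values, list):
        return [np.asarray(v) for v in shap_values]
    # Unwrap Explanation-like objects; plain NumPy arrays have no .values.
    values = getattr(shap_values, "values", shap_values)
    arr = np.asarray(values)
    if arr.ndim == 3:
        # Split along the last (class) axis into one matrix per class.
        return [arr[:, :, k] for k in range(arr.shape[2])]
    if arr.ndim == 2:
        # Single-output model: wrap the one matrix in a list.
        return [arr]
    raise ValueError(f"expected 2-D or 3-D SHAP values, got shape {arr.shape}")


# Usage with the objects from the question (requires shap to be installed):
# import shap
# per_class = to_per_class_list(explainer.shap_values(train_x))
# shap.summary_plot(per_class, train_x, plot_type="bar")
```

Passing the DataFrame itself as the second argument lets summary_plot pick up the column names, which is what went missing when the NumPy slice was used.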

The documentation here should really be updated...
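The TypeError from the question can be reproduced without SHAP. The failing line, feature_names[sort_inds], indexes a plain Python list with a NumPy array of indices, which lists don't support. A minimal sketch (the feature names and scores are made up for illustration):

```python
import numpy as np

feature_names = ["age", "income", "score"]  # made-up names, stored as a plain list
sort_inds = np.argsort([0.2, 0.9, 0.5])[::-1]  # feature indices, largest score first

caught = None
try:
    feature_names[sort_inds]  # a list cannot be indexed by an array of indices
except TypeError as err:
    caught = err
    print("TypeError:", err)

# A NumPy array supports indexing with an integer array, so this works:
reordered = np.array(feature_names)[sort_inds].tolist()
print(reordered)
```

Converting the names to a NumPy array would only hide that symptom. With 3-D values the plot would still not be a bar plot, so switching to explainer.shap_values() is what actually fixes it.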
