How to set a legend on a matplotlib Axes in a subfigure plot

Time:05-21

I am trying to plot the decision boundary of an SVM classifier using the iris dataset. The class labels don't appear in the legend, even though I set label=y.

Code:

import matplotlib.pyplot as plt 
from sklearn import svm, datasets 
from sklearn.inspection import DecisionBoundaryDisplay

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

linear = svm.LinearSVC()
linear.fit(X,y)

X0, X1 = X[:, 0], X[:, 1]

fig, ax = plt.subplots(figsize=(10, 6))
disp = DecisionBoundaryDisplay.from_estimator(linear, X,
    response_method='predict',cmap=plt.cm.coolwarm, alpha=.8,ax=ax,
    xlabel=iris.feature_names[0],ylabel=iris.feature_names[1],label=y)
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title('Some title') 
ax.legend()
plt.show()


Answer:

You could use: [image not available]

Answer:

Running your code, I get this warning:

UserWarning: The following kwargs were not used by contour: 'label'

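This is caused by passing label=y to DecisionBoundaryDisplay.from_estimator. Extra keyword arguments are passed on to the plotting method (contourf by default), which does not use label, so no legend entry is created for the decision regions.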
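The underlying rule is that Axes.legend() only collects artists whose label is set and does not start with an underscore. A minimal matplotlib-only sketch of that behaviour follows; the toy grid is hypothetical data standing in for the predict() surface that DecisionBoundaryDisplay draws, not the iris model itself.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical toy grid with three integer-valued regions, standing in for
# the predict() surface that DecisionBoundaryDisplay draws with contourf.
xx, yy = np.meshgrid(np.linspace(0, 3, 60), np.linspace(0, 1, 20))
zz = np.clip(np.floor(xx), 0, 2)

fig, ax = plt.subplots()
ax.contourf(xx, yy, zz, cmap=plt.cm.coolwarm, alpha=0.8)

# Only artists with a label that does not start with "_" are collected.
# The filled contour carries no such label, so nothing is found yet.
_, labels_before = ax.get_legend_handles_labels()
print(labels_before)  # prints []

# A scatter call with an explicit string label is picked up.
ax.scatter([0.5], [0.5], color=plt.cm.coolwarm(0.0), edgecolors="k", label="class 0")
_, labels_after = ax.get_legend_handles_labels()
print(labels_after)  # prints ['class 0']
```

This is why the scatter-based approach below works: each labelled scatter call becomes one legend entry.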

If you want to show a legend, I suggest working with the scatter plots instead, something like this:

# Same call as in the question, but without label=y
disp = DecisionBoundaryDisplay.from_estimator(linear, X,
    response_method='predict', cmap=plt.cm.coolwarm, alpha=.8, ax=ax,
    xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])

# One scatter call per class, each with its own label, so ax.legend() finds one entry per class
classes = sorted(set(y))
for c in classes:
    ax.scatter(X0[y == c], X1[y == c], color=plt.cm.coolwarm(c / max(classes)),
               s=20, edgecolors='k', label=c)
ax.legend()