Plot Legends in a Classification Graph

Time:04-07

I've plotted a 3-D graph in Python showing 3 features (say A, B, C), with each point coloured by the target variable "y_train". I want to add a legend that shows the category for each colour, but it only shows one of them, "2". Any idea how to fix this?

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X_train["A"], X_train["B"], X_train["C"], c=y_train)

ax.set_xlabel("A")
ax.set_ylabel("B")
ax.set_zlabel("C")

ax.legend(y_train)  # only one entry ("2") appears
plt.show()
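For reference, here is a minimal reproduction of that behaviour with synthetic data (the arrays a, b, c and y are hypothetical stand-ins for the three feature columns and y_train). Called with a single argument, ax.legend treats it as a list of labels and pairs them, in order, with the artists already on the axes. One ax.scatter call creates one artist, so only the first value of y gets a legend entry.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old Matplotlib)

# Hypothetical stand-ins for X_train["A"], X_train["B"], X_train["C"] and y_train
rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(3, 30))
y = np.array([2, 0, 1] * 10)  # three classes; the first value is 2

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(a, b, c, c=y)  # one call -> one artist (a single collection)

# legend(labels) zips the labels with the existing artists, so only y[0] is used
leg = ax.legend(y)
print([t.get_text() for t in leg.get_texts()])  # -> ['2']
```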


Answer:

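For automated legend creation, use the PathCollection.legend_elements method on the object returned by ax.scatter. It returns a (handles, labels) tuple with one entry per distinct value in the colour array, so you can unpack it with * and pass it straight to ax.legend.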
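A self-contained sketch of what legend_elements() returns, assuming synthetic data (a, b, c and y are hypothetical stand-ins for the X_train columns and y_train). The title="y_train" argument is an optional addition, not part of the original answer.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old Matplotlib)

# Hypothetical stand-ins for the feature columns and the class labels
rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(3, 30))
y = np.array([2, 0, 1] * 10)

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection="3d")
scatter = ax.scatter(a, b, c, c=y)

# (handles, labels): one marker handle per distinct value of y,
# coloured with the same colormap as the scatter points
handles, labels = scatter.legend_elements()
legend = ax.legend(handles, labels, title="y_train")
# plt.show()  # uncomment in an interactive session
```

Writing ax.legend(*scatter.legend_elements()) is the same call in one line.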

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, projection='3d')

scatter = ax.scatter(X_train["A"], X_train["B"], X_train["C"], c=y_train)
# legend_elements() returns (handles, labels); * unpacks them into ax.legend
ax.legend(*scatter.legend_elements())
ax.set_xlabel("A")
ax.set_ylabel("B")
ax.set_zlabel("C")

# no second ax.legend(...) call: it would replace the legend created above
plt.show()
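A different technique (not the one from the answer above) is to call ax.scatter once per class and give each call a label=; ax.legend() then finds one labelled artist per class on its own. A minimal sketch, again with hypothetical synthetic arrays a, b, c and y in place of X_train and y_train:

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old Matplotlib)

# Hypothetical stand-ins for the feature columns and the class labels
rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(3, 30))
y = np.array([2, 0, 1] * 10)

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111, projection="3d")

# One scatter call (and therefore one labelled artist) per class
for cls in np.unique(y):
    mask = y == cls
    ax.scatter(a[mask], b[mask], c[mask], label=f"class {cls}")

legend = ax.legend(title="y_train")  # picks up the label= of each artist
# plt.show()  # uncomment in an interactive session
```

Here the colours come from the default colour cycle rather than a colormap; pass color= in each call if they need to match a specific scheme.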