I am plotting the accuracy and loss curves in TensorFlow 2/Keras using the history attribute of a trained CNN model.
import matplotlib.pyplot as plt

def plot_graphs(cnn_trained_model, string):
    plt.plot(cnn_trained_model.history[string])
    plt.plot(cnn_trained_model.history['val_' + string])
    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend([string, 'val_' + string])
    plt.show()
#Plot the accuracy and loss
plot_graphs(cnn_trained_model, "accuracy")
plot_graphs(cnn_trained_model, "loss")
The legend reads 'accuracy' and 'val_accuracy', and I want to change this to 'train_accuracy' and 'validation_accuracy'. Apparently changing plt.legend([string, 'val_' + string]) to plt.legend([string, 'validation_' + string]) does not work.
I have tried adding the legend manually like so
from matplotlib.patches import Rectangle
...
extra = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)
plt.legend([extra, t_line, v_line], ("Accuracy", "test", "validation"))
This does not work either: it displays a box with just 'Accuracy' in it, plus this message:
C:\Users\myusername\AppData\Local\Temp\ipykernel_1472\2335663950.py:10: UserWarning: Legend does not support [<matplotlib.lines.Line2D object at 0x00000227ED9F7D60>] instances.
A proxy artist may be used instead.
See: https://matplotlib.org/users/legend_guide.html#creating-artists-specifically-for-adding-to-the-legend-aka-proxy-artists
plt.legend([extra, t_line, v_line], ("Accuracy", "test", "validation"))
C:\Users\myusername\AppData\Local\Temp\ipykernel_1472\2335663950.py:10: UserWarning: Legend does not support [<matplotlib.lines.Line2D object at 0x00000227EDA050A0>] instances.
A proxy artist may be used instead.
See: https://matplotlib.org/users/legend_guide.html#creating-artists-specifically-for-adding-to-the-legend-aka-proxy-artists
plt.legend([extra, t_line, v_line], ("Accuracy", "test", "validation"))
How can I change the legend please?
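The warning text `Legend does not support [<matplotlib.lines.Line2D object ...>] instances`, with the Line2D wrapped in square brackets, is what matplotlib prints when a legend handle is a list rather than an artist. `plt.plot` returns a list of `Line2D` objects, so if the elided code assigned `t_line = plt.plot(...)`, each handle would be a one-element list. That is an assumption, since the lines defining `t_line` and `v_line` are not shown. A minimal sketch of the unpacking fix, using made-up accuracy values:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Hypothetical per-epoch values standing in for history['accuracy'] etc.
train_acc = [0.60, 0.72, 0.80, 0.85]
val_acc = [0.58, 0.66, 0.71, 0.73]

# plt.plot returns a *list* of Line2D objects; the trailing comma unpacks
# the single line, so the legend receives a Line2D instead of a list.
t_line, = plt.plot(train_acc)
v_line, = plt.plot(val_acc)

# Invisible rectangle used as a title-like first legend entry (a proxy artist).
extra = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)
legend = plt.legend([extra, t_line, v_line], ("Accuracy", "train", "validation"))
```

With the handles unpacked, all three entries appear, and the invisible `Rectangle` acts as the heading the proxy-artist link in the warning describes.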
Answer:
Because you call plt.show() inside the function, the figure is displayed as soon as the function runs, so the legend has to be changed inside the function, before plt.show(). One easy way is to change the line
plt.legend([string, 'val_' + string])
to
if string == 'accuracy':
    plt.legend(['train_accuracy', 'validation_accuracy'])
else:
    plt.legend([string, 'val_' + string])
and it should work.
(Resulting plot, generated with random data.)
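An alternative to the if/else, sketched here on the assumption that only the `.history` dict of the trained object is needed: pass `label=` to each `plt.plot` call and let `plt.legend()` collect the labels, which renames the entries for every metric rather than just accuracy. The `SimpleNamespace` stand-in and its numbers are hypothetical, used only so the sketch runs without training a model, and the `Agg` backend is selected so it also runs headless (drop that line in a notebook).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove when plotting interactively
import matplotlib.pyplot as plt
from types import SimpleNamespace


def plot_graphs(cnn_trained_model, string):
    """Plot the training and validation curves for one metric."""
    history = cnn_trained_model.history
    plt.figure()  # start a fresh figure so each metric gets its own plot
    # Attach each label to its own line, so the legend cannot get out of
    # sync with the order in which the lines were plotted.
    plt.plot(history[string], label="train_" + string)
    plt.plot(history["val_" + string], label="validation_" + string)
    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend()  # collects the label= strings set above
    plt.show()


# Hypothetical stand-in for a Keras History object: only .history is used.
fake_history = SimpleNamespace(history={
    "accuracy": [0.60, 0.72, 0.80, 0.85],
    "val_accuracy": [0.58, 0.66, 0.71, 0.73],
    "loss": [1.10, 0.80, 0.60, 0.50],
    "val_loss": [1.20, 0.95, 0.85, 0.80],
})

plot_graphs(fake_history, "accuracy")
plot_graphs(fake_history, "loss")
```

Because the labels travel with the lines, the same function yields 'train_loss' and 'validation_loss' for the loss plot without a special case per metric.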