How to customize legend when plotting learning curves in keras?

Time:08-31

I am plotting the accuracy and loss curves in TensorFlow 2/Keras using the history attribute of a trained CNN model.

import matplotlib.pyplot as plt

def plot_graphs(cnn_trained_model, string):
    plt.plot(cnn_trained_model.history[string])
    plt.plot(cnn_trained_model.history['val_' + string])
    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend([string, 'val_' + string])
    plt.show()


# Plot the accuracy and loss
plot_graphs(cnn_trained_model, "accuracy")
plot_graphs(cnn_trained_model, "loss")

It looks like this (plot image not reproduced here).

The legend reads 'accuracy' and 'val_accuracy', and I want to change this to 'train_accuracy' and 'validation_accuracy'. Changing plt.legend([string, 'val_' + string]) to plt.legend([string, 'validation_' + string]) does not work.

I have tried adding the legend manually like so:

from matplotlib.patches import Rectangle
...
   extra = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)
   plt.legend([extra, t_line, v_line], ("Accuracy", "test", "validation"))

This does not work either: it displays a box with just 'Accuracy' in it, along with this message:

C:\Users\myusername\AppData\Local\Temp\ipykernel_1472\2335663950.py:10: UserWarning: Legend does not support [<matplotlib.lines.Line2D object at 0x00000227ED9F7D60>] instances.
A proxy artist may be used instead.
See: https://matplotlib.org/users/legend_guide.html#creating-artists-specifically-for-adding-to-the-legend-aka-proxy-artists
  plt.legend([extra, t_line, v_line], ("Accuracy", "test", "validation"))
C:\Users\myusername\AppData\Local\Temp\ipykernel_1472\2335663950.py:10: UserWarning: Legend does not support [<matplotlib.lines.Line2D object at 0x00000227EDA050A0>] instances.
A proxy artist may be used instead.
See: https://matplotlib.org/users/legend_guide.html#creating-artists-specifically-for-adding-to-the-legend-aka-proxy-artists
  plt.legend([extra, t_line, v_line], ("Accuracy", "test", "validation"))

How can I change the legend please?

CodePudding user response:

Because plt.show() is called inside the function, the figure is rendered at that point, so the legend has to be changed inside the function as well. One easy way is to change the line

    plt.legend([string, 'val_' + string])

to

    if string == 'accuracy':
        plt.legend(['train_accuracy', 'validation_accuracy'])
    else:
        plt.legend([string, 'val_' + string])

... it should work.
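
A more general variant, as a minimal sketch rather than the only way to do it: pass label= to each plot call and let legend() pick those labels up, so both accuracy and loss get the train_/validation_ wording without an if. In this sketch the function takes the history dict directly, and fake_history holds made-up numbers standing in for cnn_trained_model.history. The Agg backend line is only an assumption so the snippet runs without a display; drop it in a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; omit in a notebook
import matplotlib.pyplot as plt


def plot_graphs(history, metric):
    """Plot a training metric and its validation counterpart with custom legend text.

    `history` is a dict shaped like the `.history` attribute returned by `model.fit`.
    """
    fig, ax = plt.subplots()
    ax.plot(history[metric], label="train_" + metric)
    ax.plot(history["val_" + metric], label="validation_" + metric)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(metric)
    ax.legend()  # uses the label= text given to each plot call
    return ax


# Made-up values, only for illustration (not real training results).
fake_history = {
    "accuracy": [0.61, 0.74, 0.82],
    "val_accuracy": [0.58, 0.70, 0.76],
    "loss": [0.95, 0.62, 0.44],
    "val_loss": [1.01, 0.71, 0.59],
}

ax = plot_graphs(fake_history, "accuracy")
legend_labels = [text.get_text() for text in ax.get_legend().get_texts()]
```

With the real model this would be called as plot_graphs(cnn_trained_model.history, "accuracy"), followed by plt.show() if the figure is not displayed automatically.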

Plots with random data (two images in the original post, not reproduced here).
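
About the UserWarning in the question: the square brackets in "Legend does not support [<matplotlib.lines.Line2D ...>] instances" suggest that t_line and v_line were lists, which is what plot() returns. The question elides how they were created, so this is an assumption, but if they came from something like t_line = plt.plot(...), unpacking the single line with a trailing comma should make the manual legend work. A minimal sketch with made-up data and placeholder label text:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; omit in a notebook
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Made-up values, only for illustration.
train_acc = [0.61, 0.74, 0.82]
val_acc = [0.58, 0.70, 0.76]

fig, ax = plt.subplots()

# plot() returns a list of Line2D objects; the trailing comma unpacks the one line.
t_line, = ax.plot(train_acc)
v_line, = ax.plot(val_acc)

# Invisible rectangle used as a title-like first legend entry, as in the question.
extra = Rectangle((0, 0), 1, 1, fill=False, edgecolor="none", linewidth=0)

legend = ax.legend([extra, t_line, v_line], ("Accuracy", "train", "validation"))
labels = [text.get_text() for text in legend.get_texts()]
```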
