I am trying to plot some data from a CSV file. I load the file with pandas and draw the lines with sns.lineplot(). However, one of the legend entries is always wrong: it shows a square instead of a line.
import matplotlib.pyplot as plt
import seaborn as sns

# df is the pandas DataFrame loaded from the CSV file
plt.figure(dpi=150)
lin1 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "tanh"])
lin2 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "silu"])
lin3 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "swish"])
plt.xlabel("Training time (sec)")
plt.legend(("tanh", "silu", "swish"))
plt.yscale("log", base=10)
I call sns.lineplot() three times, once per activation, because the data contains more activations than these three. This is the resulting plot.
The lines themselves look correct, but the legend does not. These are the versions of the plotting tools I am using:
Python 3.9.12
matplotlib 3.6.1
matplotlib-inline 0.1.6
seaborn 0.12.1
I could not find this issue reported anywhere online, and restarting the kernel does not help. Please let me know if more information is needed.
Answer:
Try passing the label argument to each sns.lineplot() call and then calling plt.legend() without any arguments:
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(dpi=150)
# Give each line its own label so the legend only picks up these lines
lin1 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "tanh"], label="tanh")
lin2 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "silu"], label="silu")
lin3 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "swish"], label="swish")
plt.xlabel("Training time (sec)")
plt.legend()  # no arguments: uses the labels set above
plt.yscale("log", base=10)
Answer:
You can also plot all three lines with a single call by using hue:
sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"].isin(["tanh", "silu", "swish"])], hue="Activation")