add precision-recall curves to plot using a function

Time:12-17

I have a dataframe with precision and recall for many classifiers. Each classifier was run with 4 different values of the confidence threshold:

    MODEL  CONFIDENCE_THR  PRECISION    RECALL
0  Model1            0.25   0.992647  0.950704
1  Model1            0.45   1.000000  0.929577
2  Model1            0.35   0.992537  0.936620
3  Model1            0.30   0.992593  0.943662
4  Model2            0.45   0.992647  0.950704
5  Model2            0.30   0.992647  0.950704
6  Model2            0.35   0.992647  0.950704
7  Model2            0.25   0.992701  0.957746
8  Model3            0.30   0.978417  0.957746
9  Model3            0.35   0.978102  0.943662
...

I would like to create a Matplotlib plot in JupyterLab and add a precision-recall curve for each model. Since the list of models may change in the future, I would like to do this with a Python function instead of hardcoding the model names in the Matplotlib code. I wrote something like this:

import pandas as pd
import matplotlib.pyplot as plt
%matplotlib widget

df = pd.read_csv("results.csv")

plt.title('Precision-Recall curves')
plt.legend(loc = 'lower right')

def plot_precision_recall_curve(df, model, plt):
    df = df.loc[df['MODEL'] == model, ['MODEL', 'PRECISION', 'RECALL']]
    plt.plot(df['RECALL'], df["PRECISION"], 'b', label = model)  

plot_precision_recall_curve(df, 'Model1', plt)
plt.show()

But I get an empty plot and the message:

No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.

Also, note that my code only tries to plot the precision-recall curve of a single model, but what I actually want is to add the precision-recall curves of all models in the dataframe to the same plot. Can you help me?

Answer:

To fix your code, move plt.legend() so that it runs after you plot something, ideally right before plt.show():

plt.title('Precision-Recall curves')
plot_precision_recall_curve(df, 'Model1', plt)
plt.legend(loc = 'lower right')
plt.show()
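The warning appears because legend() with no arguments only collects artists that already carry a label, and at the point where the original code calls it, nothing has been plotted yet. A minimal sketch of that behaviour, assuming the non-interactive Agg backend so it runs outside JupyterLab; the x/y values are simply the two Model1 rows with the lowest and highest recall from the question:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs outside JupyterLab
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Before plotting: no labelled artists exist, so legend() has nothing to show
handles_before, labels_before = ax.get_legend_handles_labels()
print(labels_before)

# Plot one curve with a label (recall on x, precision on y)
ax.plot([0.929577, 0.950704], [1.000000, 0.992647], label="Model1")

# After plotting: the labelled line is picked up
handles_after, labels_after = ax.get_legend_handles_labels()
print(labels_after)

# Now the legend has an entry for the plotted line
ax.legend(loc="lower right")
```

In JupyterLab you would drop the `matplotlib.use("Agg")` line and keep `%matplotlib widget`.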

As for plotting all models at once, if you are open to using seaborn, it can draw one line per model directly:

import seaborn as sns
sns.lineplot(data=df, x='RECALL', y='PRECISION', hue='MODEL')
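If you would rather stay with plain Matplotlib, the function from the question can be called once per model by looping over `df['MODEL'].unique()`. The sketch below assumes the column names shown in the question and uses a small in-memory copy of the visible rows in place of `results.csv`. Two changes are additions of this sketch, not part of the original answer: each model's rows are sorted by recall so the line is drawn left to right rather than in threshold order, and the hardcoded `'b'` colour is dropped so Matplotlib assigns each model its own colour. The function also takes an `Axes` object instead of the `plt` module.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit in JupyterLab
import matplotlib.pyplot as plt
import pandas as pd

# Copy of the rows shown in the question (stands in for results.csv)
df = pd.DataFrame({
    "MODEL": ["Model1"] * 4 + ["Model2"] * 4 + ["Model3"] * 2,
    "CONFIDENCE_THR": [0.25, 0.45, 0.35, 0.30, 0.45, 0.30, 0.35, 0.25, 0.30, 0.35],
    "PRECISION": [0.992647, 1.000000, 0.992537, 0.992593,
                  0.992647, 0.992647, 0.992647, 0.992701,
                  0.978417, 0.978102],
    "RECALL": [0.950704, 0.929577, 0.936620, 0.943662,
               0.950704, 0.950704, 0.950704, 0.957746,
               0.957746, 0.943662],
})

def plot_precision_recall_curve(df, model, ax):
    # Keep only this model's rows, ordered by recall so x increases along the line
    sub = df.loc[df["MODEL"] == model].sort_values("RECALL")
    # No fixed colour: Matplotlib's colour cycle gives each model a distinct colour
    ax.plot(sub["RECALL"], sub["PRECISION"], marker="o", label=model)

fig, ax = plt.subplots()
ax.set_title("Precision-Recall curves")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")

# One curve per model found in the data, so new models need no code changes
for model in df["MODEL"].unique():
    plot_precision_recall_curve(df, model, ax)

ax.legend(loc="lower right")  # called after plotting, so every line is listed
plt.show()
```

Because the model names come from the data, adding a Model4 to `results.csv` adds a fourth line without touching the plotting code.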

