PR-ROC curve with cross validation strange behaviour

Time:02-20

Inspired by strange PR-ROC curve

The code is below:

rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=42)

ppv_arr = list()
pr_auc_arr = list()
base_tpr = np.linspace(0, 1, 101)

for train_index, test_index in rskf.split(X, y):
    
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
    
    calibrated_clf.fit(X_train, y_train)
    
    y_hat = calibrated_clf.predict_proba(X_test)
    ppv, tpr, _ = precision_recall_curve(y_test, y_hat[:, 1], pos_label='positive')
    
    pr_auc = average_precision_score(y_test, y_hat[:, 1], pos_label='positive')
    pr_auc_arr.append(pr_auc)
    
    plt.plot(ppv, tpr, color='r', alpha=0.15)
    ppv = np.interp(base_tpr, ppv, tpr)
    ppv[0] = 0.0
    ppv_arr.append(ppv)

ppv_arr = np.array(ppv_arr)
mean_ppv = ppv_arr.mean(axis=0)
std = ppv_arr.std(axis=0)

ppv_upper = np.minimum(mean_ppv + std, 1)
ppv_lower = mean_ppv - std

plt.plot(mean_ppv, base_tpr, label=f'AUC: {np.mean(pr_auc_arr):.2f}', color='r')
plt.fill_between(base_tpr, ppv_lower, ppv_upper, color='grey', alpha=0.3)
plt.plot([0, 1], [1, 0], 'b--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('Positive Predictive Value')
plt.xlabel('True Positive Rate')
plt.title('KNN PR-ROC Curve and PR-AUC')
plt.legend(loc='best')
plt.show()

I'm not sure what the issue is. I've looked at the weblink and since it's based on a ROC curve there might be something ROC-related that I accidentally included, but I haven't seen anything. Perhaps this is what a PR-ROC with CV looks like? Or perhaps this particular model (KNN) is just bad and that's why the curve looks so strange.

Edit: response to answer

still strange PR-ROC curve

Edit 2: after changing ppv[0] = 0.0 to ppv[0] = 1.0

better PR-ROC curve

Edit 3: response to edited answer

back to strange PR-ROC curve

CodePudding user response:

plt.plot takes the x array first and the y array second, and it looks like you have them switched. It should be

plt.plot(tpr, ppv, color='r', alpha=0.15)

and

plt.plot(base_tpr, mean_ppv, label=f'AUC: {np.mean(pr_auc_arr):.2f}', color='r')

The coordinates are also switched in the np.interp call, so that should be

ppv = np.interp(base_tpr, tpr, ppv)
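
One caveat with the np.interp call above, offered as a sketch rather than a definitive fix: np.interp expects its x-coordinates (the second argument) to be increasing and does not check this, while precision_recall_curve returns recall in decreasing order. The arrays below are made-up values shaped like that output, not results from the question's model. Reversing both arrays before interpolating is one way to meet the requirement.

```python
import numpy as np

# Hypothetical values shaped like precision_recall_curve output:
# recall (tpr) runs from 1 down to 0, and the last precision (ppv) is 1.
tpr = np.array([1.0, 0.75, 0.5, 0.25, 0.0])
ppv = np.array([0.5, 0.6, 0.7, 0.9, 1.0])
base_tpr = np.linspace(0, 1, 5)

# np.interp assumes the x-coordinates are increasing and does not verify it,
# so passing the decreasing recall array gives values that should not be trusted.
untrusted = np.interp(base_tpr, tpr, ppv)

# Reversing both arrays makes recall increasing, as np.interp expects.
interp_ppv = np.interp(base_tpr, tpr[::-1], ppv[::-1])
print(interp_ppv)
```

With the reversed arrays, precision at recall 0 is 1.0, so there is no need to overwrite ppv[0] by hand.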

CodePudding user response:

After being unable to figure out what was causing the issue, I decided to start again from scratch, following the ROC curve example less strictly (my PR-ROC curve was behaving too much like a ROC curve).

The new, working code is below for reference:

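import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import RepeatedStratifiedKFold

# X, y and calibrated_clf are defined as in the question.
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=42)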

y_real = list()
y_proba = list()
ppv_arr = list()
tpr_arr = np.linspace(0, 1, 100)
    
for train_index, test_index in rskf.split(X, y):
    
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
    
    calibrated_clf.fit(X_train, y_train)
    
    y_hat = calibrated_clf.predict_proba(X_test)
    
    ppv, tpr, _ = precision_recall_curve(y_test, y_hat[:, 1], pos_label='positive')
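    # Recall is returned in decreasing order; reverse both arrays so that
    # np.interp receives increasing x-coordinates.
    ppv, tpr = ppv[::-1], tpr[::-1]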
    
    precision_arr = np.interp(tpr_arr, tpr, ppv)
    pr_auc = average_precision_score(y_test, y_hat[:, 1], pos_label='positive')
    ppv_arr.append(precision_arr)
    
    plt.subplot(222)
    plt.plot(tpr, ppv, color='r', alpha=0.15)
    
    y_real.append(y_test)
    y_proba.append(y_hat[:, 1])

y_real = np.concatenate(y_real)
y_proba = np.concatenate(y_proba)

ppv, tpr, _ = precision_recall_curve(y_real, y_proba, pos_label='positive')

average_ppv = average_precision_score(y_real, y_proba, pos_label='positive')
mean_ppv = np.mean(ppv_arr, axis=0)
std_ppv = np.std(ppv_arr, axis=0)

plt.subplot(222)
plt.plot(tpr, ppv, color='r', label=f'AUC: {average_ppv:.4f}')
plt.fill_between(tpr_arr, mean_ppv + std_ppv, mean_ppv - std_ppv, alpha=0.3, linewidth=0, color='grey')
plt.plot([0, 1], [1, 0], 'b--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('Positive Predictive Value')
plt.xlabel('True Positive Rate')
plt.title('KNN PR-ROC Curve and PR-AUC')
plt.legend(loc='best')
plt.show()

The PR-ROC curve looks like so:

working PR-ROC curve
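
A side note on the dashed reference line: the diagonal drawn by plt.plot([0, 1], [1, 0], 'b--') is carried over from the ROC example. For a precision-recall plot, the usual no-skill reference is a horizontal line at the fraction of positive samples, since a classifier scoring at random has roughly that precision at every recall level. A minimal sketch, using made-up labels in place of the concatenated y_real:

```python
import numpy as np

# Hypothetical labels, standing in for the concatenated y_real.
y_real = np.array(['positive', 'negative', 'negative', 'positive',
                   'negative', 'negative', 'negative', 'negative'])

# Share of positive samples: the expected precision of a random classifier.
prevalence = np.mean(y_real == 'positive')
print(f'No-skill precision: {prevalence:.2f}')

# Assuming plt is matplotlib.pyplot, the reference line could be drawn as:
# plt.plot([0, 1], [prevalence, prevalence], 'b--')
```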
