I am trying to find the corresponding confusion matrix (binary classification) given some metrics (such as accuracy, sensitivity, precision, F1-score).
I know it is easy to compute the confusion matrix from the ground-truth and predicted labels. However, in my case, I don't have the predicted labels.
For example, I have
acc, sen, pre, f1_score = 68.00, 51.28, 80.00, 62.50
Is there a way to find the respective confusion matrix?
CodePudding user response:
Yes, you can recover the confusion matrix (for binary classification), but you need to know the original number of positive (P) and negative (N) examples.
For example, if you know you have N = 6 and P = 4, and know the sensitivity is sen = 0.75, then you can plug them into the equation for sensitivity to get the number of true positives (TP):
sen = TP / P
0.75 = TP / 4
TP = 3
Now if you know the accuracy is acc = 0.699, you can solve for the number of true negatives (TN):
acc = (TP + TN) / (P + N)
0.699 = (3 + TN) / (4 + 6)
TN ≈ 3.99, which rounds to 4 (counts must be integers)
Which gives you enough to reconstruct the confusion matrix:
TN          (N - TN)
(P - TP)    TP
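Putting the two equations together, here is a minimal sketch of a helper that performs the reconstruction (the name reconstruct_confusion_matrix is just illustrative, not an existing API), assuming the reported metrics are precise enough that rounding recovers the true integer counts:
import numpy as np

def reconstruct_confusion_matrix(acc, sen, P, N):
    # Recover TP from sensitivity and TN from accuracy,
    # then fill in FP and FN by subtraction.
    TP = round(sen * P)
    TN = round(acc * (P + N)) - TP
    return np.array([[TN, N - TN],
                     [P - TP, TP]])

reconstruct_confusion_matrix(0.699, 0.75, P=4, N=6)
# array([[4, 2],
#        [1, 3]])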
Here is an example to help demonstrate this:
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, confusion_matrix
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]) # N = 6, P = 4
y_pred = np.array([1, 1, 0, 0, 0, 0, 0, 1, 1, 1])
accuracy_score(y_true, y_pred)
# 0.7 (the acc = 0.699 used above is this value, rounded down)
recall_score(y_true, y_pred)
# 0.75, i.e., recall is the same as sensitivity
confusion_matrix(y_true, y_pred)
# array([[4, 2],
# [1, 3]])
CodePudding user response:
Thank you so much @Alexander L. Hayes, that is much easier than a grid search. Here is the code I ended up with:
import numpy as np

gTruth = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
N = len(np.where(np.array(gTruth) == 0)[0])  # number of negative examples
P = len(np.where(np.array(gTruth) == 1)[0])  # number of positive examples
conf_mat = np.zeros((2, 2))
acc_1 = 0.699
sen_1 = 0.75
# Recover the counts from the metric definitions
TP = round(sen_1 * P)
TN = round((acc_1 * (N + P)) - TP)
FP = N - TN
FN = P - TP
conf_mat[0, 0] = TN
conf_mat[0, 1] = FP
conf_mat[1, 0] = FN
conf_mat[1, 1] = TP
# Recompute the metrics from the reconstructed matrix as a sanity check
acc = (TP + TN) / (TP + TN + FP + FN)
sen = TP / (TP + FN)
spe = TN / (TN + FP)
try:
    pre = TP / (TP + FP)
except ZeroDivisionError:
    pre = 0.0
try:
    f1 = 2 * (pre * sen) / (pre + sen)
except ZeroDivisionError:
    f1 = 0.0
print(conf_mat)
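As a quick follow-up check (a sketch that reuses gTruth above and borrows the y_pred from the example in the first answer), the reconstructed matrix can be compared with what sklearn computes from the actual labels:
from sklearn.metrics import confusion_matrix

y_pred = [1, 1, 0, 0, 0, 0, 0, 1, 1, 1]  # predictions from the example above
print(confusion_matrix(gTruth, y_pred))
# [[4 2]
#  [1 3]]  -- same counts as conf_mat
print(acc, sen, spe, pre, f1)  # metrics recomputed from the reconstructed matrix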