Confusion Matrix Detailed Explanation of Dataset

Time:06-26

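import pandas as pd

# ignore_index=True gives every row a unique index after stacking the files
df = pd.concat(map(pd.read_csv, ['A.csv', 'B1.csv', 'B2.csv', 'B3.csv', 'C1.csv', 'C2.csv', 'C3.csv']), ignore_index=True)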

As shown above, I combined multiple datasets to train and test supervised classification models (DT, RF, SVM, KNN, MLP). I used a confusion matrix and a classification report to evaluate the models.

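import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

classes = ['Good', 'Bad']  # must match the (sorted) label order confusion_matrix uses
def plot_confusionmatrix(prediction, real, dom):
    print(f'{dom} Confusion matrix')
    # scikit-learn expects (y_true, y_pred): rows = actual class, columns = predicted class
    cf = confusion_matrix(real, prediction)
    sns.heatmap(cf, annot=True, yticklabels=classes,
                xticklabels=classes, cmap='Blues', fmt='g')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.tight_layout()
    plt.show()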
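Because scikit-learn's `confusion_matrix(y_true, y_pred)` sorts the labels it finds unless you pass `labels=`, the tick labels in the heatmap only line up if `classes` is in that same order. A minimal sketch, assuming the `Class` column holds the strings 'Good' and 'Bad' (the toy arrays are made up for illustration):

```python
from sklearn.metrics import confusion_matrix

classes = ['Good', 'Bad']

# Toy labels, invented for illustration only
y_true = ['Good', 'Bad', 'Good']
y_pred = ['Good', 'Good', 'Bad']

# Default: labels are sorted, so row/column order is ['Bad', 'Good']
cm_default = confusion_matrix(y_true, y_pred)

# Explicit labels: row/column order follows `classes`
cm_ordered = confusion_matrix(y_true, y_pred, labels=classes)

print(cm_default)
print(cm_ordered)
```

In both matrices rows are actual classes and columns are predicted classes; only the ordering differs.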

........... TRAINING THE MODEL WITH ONE OF THE ALGORITHMS

from sklearn.metrics import accuracy_score, classification_report

print(f'Test score {accuracy_score(y_test, y_test_DT_predicted)}')
print("-------------")
plot_confusionmatrix(y_test_DT_predicted, y_test, dom='Test')
print('Classification Report\n\n', classification_report(y_test, y_test_DT_predicted))


Instead of showing the confusion matrix for the entire test set, how can I focus only on the samples that come from dataset A? I am interested in seeing how these samples are classified (**** samples from A, **** samples from B1, ...).

Edit: As suggested, I added a column to each dataset that records its name, and then merged them. I will drop that column before training the model. But how can I still use it to show the results in the confusion matrix?

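# Tag each source DataFrame before merging, so the tag survives the concat
A['Dataset'] = "A"
B1['Dataset'] = "B1"
B2['Dataset'] = "B2"
B3['Dataset'] = "B3"
C1['Dataset'] = "C1"
C2['Dataset'] = "C2"
C3['Dataset'] = "C3"
df = pd.concat([A, B1, B2, B3, C1, C2, C3], ignore_index=True)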
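The tags have to be added before the frames are concatenated; a `df` built directly from `map(pd.read_csv, ...)` will not contain them. One compact way to do both at once is to tag each frame with `DataFrame.assign` as it is read. A minimal sketch, where `io.StringIO` objects stand in for the real CSV files and the column name `feature` is an assumption:

```python
import io

import pandas as pd

# Stand-ins for 'A.csv', 'B1.csv', ...; in practice these would be file paths
sources = {
    'A': io.StringIO('feature,Class\n1.0,Good\n2.0,Bad\n'),
    'B1': io.StringIO('feature,Class\n3.0,Good\n'),
}

# Read each file, record where each row came from, then stack them
df = pd.concat(
    (pd.read_csv(src).assign(Dataset=name) for name, src in sources.items()),
    ignore_index=True,
)
print(df)
```

With the real files, `sources` could be `{name: f'{name}.csv' for name in ['A', 'B1', 'B2', 'B3', 'C1', 'C2', 'C3']}`.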

X = df.drop(['Class', 'Dataset'], axis=1)
y = df['Class']
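Dropping `Dataset` from `X` is fine as long as the column is kept on the side and split together with `X` and `y`. `train_test_split` accepts any number of equally long arrays and shuffles them identically, so the test rows keep their source tag. A minimal sketch on synthetic data (the feature values, the `DecisionTreeClassifier`, the `groups` name and `random_state` are assumptions, not the asker's setup):

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the merged df (made-up values)
rng = np.random.default_rng(0)
df = pd.DataFrame({
    'feature': rng.normal(size=60),
    'Class': rng.choice(['Good', 'Bad'], size=60),
    'Dataset': np.repeat(['A', 'B1', 'C1'], 20),
})

X = df.drop(['Class', 'Dataset'], axis=1)
y = df['Class']
groups = df['Dataset']  # kept aside, not used as a feature

# The extra array is shuffled the same way, so groups_test lines up with y_test
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
    X, y, groups, test_size=0.3, random_state=42)

model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
y_test_DT_predicted = model.predict(X_test)

# Confusion matrix restricted to the test samples that came from dataset A
mask = (groups_test == 'A').to_numpy()
cm_A = confusion_matrix(y_test[mask], y_test_DT_predicted[mask], labels=['Good', 'Bad'])
print(f'{mask.sum()} test samples from A')
print(cm_A)
```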

Answer:

There seems to be no other way than plotting a separate confusion matrix for each subset. In your case that should be easy: iterate over all dataset names, select the rows of y_test_DT_predicted and y_test that belong to each dataset, and pass the selected values to plot_confusionmatrix. Don't forget to manage multiple subplots in plt, and call plt.show() only once, after the loop.
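The loop described above could look roughly like this. It is a sketch that assumes the `Dataset` tag of every test row is available in the same order as `y_test` (for example, split alongside `X` and `y`). To avoid a seaborn dependency it draws each panel with scikit-learn's `ConfusionMatrixDisplay` instead of the heatmap from the question; the function name `plot_confusion_per_dataset` and the toy arrays are hypothetical:

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def plot_confusion_per_dataset(y_true, y_pred, groups, classes):
    """Draw one confusion matrix per source dataset on a shared figure."""
    y_true, y_pred, groups = map(np.asarray, (y_true, y_pred, groups))
    names = np.unique(groups)
    fig, axes = plt.subplots(1, len(names), figsize=(4 * len(names), 4), squeeze=False)
    matrices = {}
    for ax, name in zip(axes[0], names):
        mask = groups == name  # rows that came from this dataset
        cm = confusion_matrix(y_true[mask], y_pred[mask], labels=classes)
        matrices[name] = cm
        ConfusionMatrixDisplay(cm, display_labels=classes).plot(
            ax=ax, cmap='Blues', colorbar=False)
        ax.set_title(f'{name} ({mask.sum()} samples)')
    fig.tight_layout()
    plt.show()  # called once, after all subplots are filled
    return matrices


# Toy example with made-up labels
y_true = ['Good', 'Bad', 'Good', 'Good', 'Bad', 'Bad']
y_pred = ['Good', 'Good', 'Good', 'Bad', 'Bad', 'Bad']
groups = ['A', 'A', 'A', 'B1', 'B1', 'B1']
matrices = plot_confusion_per_dataset(y_true, y_pred, groups, classes=['Good', 'Bad'])
```

With the question's variables the call would be along the lines of `plot_confusion_per_dataset(y_test, y_test_DT_predicted, groups_test, classes)`, where `groups_test` is the hypothetical per-row dataset tag. The same boolean mask can be reused to print a `classification_report` for each dataset.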
