In image classification, how do I count the correct labels? [closed]

Time:09-29

I have a basic binary classification problem with two classes, 0 and 1. I have trained my model using Keras VGG16, but how do I count the number of correctly predicted labels, or the number of all images under label 0? For example, if my model should distinguish between cats and dogs, how do I count how many cats in the validation set were correctly labeled by the model? I do have the accuracy, but I need the actual counts.

This is my model:

import tensorflow as tf

#base_model = VGG16()

model = tf.keras.models.Sequential()
#model.add(base_model)
# Positional arguments are filters=16, kernel_size=4, strides=3
model.add(tf.keras.layers.Conv2D(16, 4, 3, input_shape=(32, 32, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(32, activation='relu'))
# Single sigmoid unit: the output is the predicted probability of class 1
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer="Adam", loss="binary_crossentropy", metrics=['accuracy'])

epochs = 50
history = model.fit(x=train_ds, y=train_lb, epochs=epochs, validation_data=(test_ds, test_lb))

train_lb and test_lb hold the corresponding labels for the training and testing datasets; each is an array containing only 0s and 1s. For example, the label for train_ds[0] is train_lb[0].

CodePudding user response:

If you want all of this information at once, I think the easiest way is to build a confusion matrix, which shows how every sample was classified:

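import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics

# Assumption: y_test is the asker's test_lb, and y_pred is the sigmoid output thresholded at 0.5
y_test = np.asarray(test_lb).ravel()
y_pred = (model.predict(test_ds) > 0.5).astype(int).ravel()
plt.figure(figsize=(8,6))
plt.title("Confusion matrix")
# Rows are true labels, columns are predicted labels: [[TN, FP], [FN, TP]]
cf_matrix = metrics.confusion_matrix(y_test, y_pred)
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
group_counts = ["{0:0.0f}".format(value) for value in
                cf_matrix.flatten()]
group_percentages = ["{0:.2%}".format(value) for value in
                     cf_matrix.flatten()/np.sum(cf_matrix)]
labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in zip(group_names,group_counts,group_percentages)]
labels = np.asarray(labels).reshape(2,2)
sns.heatmap(cf_matrix, annot=labels, fmt="", cmap='Blues')
plt.show()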
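
If you only need the raw numbers and not a plot, the same counts can be computed directly with NumPy. The following is a minimal sketch, not a definitive implementation: the helper name count_by_class is made up for illustration, the 0.5 threshold is an assumption that matches a single sigmoid output, and the demo arrays are hypothetical stand-ins for test_lb and model.predict(test_ds).

```python
import numpy as np

def count_by_class(y_true, y_prob, threshold=0.5):
    """Return {label: (n_correct, n_total)} for binary labels 0 and 1.

    Hypothetical helper: y_true holds the ground-truth labels and
    y_prob holds the sigmoid outputs, with shape (n,) or (n, 1).
    """
    y_true = np.asarray(y_true).ravel().astype(int)
    # Turn probabilities into hard 0/1 predictions (assumed 0.5 cut-off).
    y_pred = (np.asarray(y_prob).ravel() > threshold).astype(int)
    counts = {}
    for label in (0, 1):
        mask = y_true == label                      # samples whose true label is `label`
        n_correct = int(np.sum(y_pred[mask] == label))
        counts[label] = (n_correct, int(np.sum(mask)))
    return counts

# Made-up data standing in for test_lb and model.predict(test_ds).
demo_labels = np.array([0, 0, 1, 1, 1, 0])
demo_probs = np.array([[0.1], [0.7], [0.9], [0.4], [0.8], [0.2]])
demo_counts = count_by_class(demo_labels, demo_probs)
print(demo_counts)
```

With the model from the question, the call would look like count_by_class(test_lb, model.predict(test_ds)). The first number in each pair is how many images of that class were labeled correctly, and the second is how many images carry that label in total.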