I've been working on a multiclass classification problem where I need a function that shows images of a given class from the Fashion MNIST dataset together with the model's prediction for each one. For example, plot 3 images of the T-shirt class with their predictions. I have tried different things without success. I think I'm missing a conditional statement, but I can't figure out how or where to implement it in my function.
This is what I've come up with so far:
import numpy as np
import matplotlib.pyplot as plt

# Make function to plot image
def plot_image(indx, predictions, true_labels, target_images):
    """
    Picks an image, plots it and labels it with a predicted and truth label.
    Args:
        indx: index number to find the image and its true label.
        predictions: model prediction for this one image (a 1-D array of predicted probabilities between 0 and 1, one per class).
        true_labels: array of ground truth labels for images.
        target_images: images from the test data (in tensor form).
    Returns:
        A plot of an image from `target_images` with a predicted class label
        as well as the truth class label from `true_labels`.
    """
    # Set target image
    target_image = target_images[indx]
    # Truth label
    true_label = true_labels[indx]
    # Predicted label
    predicted_label = np.argmax(predictions)  # find the index of max value
    # Show image
    plt.imshow(target_image, cmap=plt.cm.binary)
    plt.xticks([])
    plt.yticks([])
    # Set colors for right or wrong predictions
    if predicted_label == true_label:
        color = 'green'
    else:
        color = 'red'
    # Labels appear on the x-axis along with the predicted probability in %
    plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                         100*np.max(predictions),
                                         class_names[true_label]),
               color=color)

# Function to display image of a class
def display_image(class_indx):
    # Set figure size
    plt.figure(figsize=(10,10))
    # Set class index
    class_indx = class_indx
    # Display 3 images
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        # plot_image function
        plot_image(indx=class_indx, predictions=y_probs[class_indx],
                   true_labels=test_labels, target_images=test_images_norm)

These are the class names: 'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'. When I call display_image() and pass an index as the argument, e.g. display_image(class_indx=15), I get the same image and the same prediction 3 times. (I know my approach is wrong: I'm passing an index number when it should be the class name.) I need a function that takes a str (the class name) and displays 3 different images of that class with their predictions. For instance, display_image('Dress') should show 3 images of the Dress class along with the 3 different predictions my model made for them, e.g. Prediction#1 (65%), Prediction#2 (100%), Prediction#3 (87%). Thanks!
CodePudding user response:
I think you are really close to solving your problem. You just need to draw three samples from your category of interest. I guess you used le = LabelEncoder() to encode your target vector. If so, you can get the classes with labels = list(le.classes_). Then I would do the following:
def display_image(class_of_interest: str, nb_samples: int = 3):
    plt.figure(figsize=(10, 10))
    # Map the class name to its integer label
    class_indx = class_names.index(class_of_interest)
    # Positions of all test images that belong to that class
    target_idx = np.where(test_labels == class_indx)[0]
    # Pick nb_samples distinct images at random
    imgs_idx = np.random.choice(target_idx, nb_samples, replace=False)
    for i in range(nb_samples):
        plt.subplot(1, nb_samples, i + 1)
        plot_image(indx=imgs_idx[i],
                   predictions=y_probs[imgs_idx[i]],
                   true_labels=test_labels,
                   target_images=test_images_norm)
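
If you want to check the selection step on its own before plotting anything, here is a minimal sketch. The helper name sample_class_indices, the seed argument, and the small demo_labels array are my own assumptions for illustration, not something from your code. It also caps nb_samples at the number of available images, because asking for more unique samples than exist raises a ValueError with replace=False.

```python
import numpy as np

# Class names as listed in the question.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def sample_class_indices(labels, class_of_interest, nb_samples=3, seed=None):
    """Return up to `nb_samples` distinct positions in `labels` whose class
    is `class_of_interest`. Hypothetical helper, not part of any library."""
    if class_of_interest not in class_names:
        raise ValueError(f"Unknown class name: {class_of_interest!r}")
    class_indx = class_names.index(class_of_interest)
    # Flatten in case the labels arrive with shape (n, 1).
    labels = np.asarray(labels).ravel()
    # Positions of every image that belongs to the requested class.
    target_idx = np.where(labels == class_indx)[0]
    # Never request more unique samples than are available.
    nb_samples = min(nb_samples, len(target_idx))
    rng = np.random.default_rng(seed)
    return rng.choice(target_idx, size=nb_samples, replace=False)


# Synthetic labels standing in for test_labels (assumption for the demo).
demo_labels = np.array([3, 0, 3, 9, 3, 1, 3, 0])
picked = sample_class_indices(demo_labels, 'Dress', nb_samples=3, seed=0)
print(picked)  # three distinct positions, all labelled 3 ('Dress')
```

Inside display_image you would then loop over the returned indices and pass each one to plot_image, exactly as in the function above.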