Sklearn dataset: MNIST random sample plotting is not working, cannot find the error

Time:11-16

I am trying to plot sample images from the MNIST dataset, but the code below raises an error. Can you tell me what I am doing wrong?

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml

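# as_frame=False returns NumPy arrays, so X[im_idx] selects rows rather than DataFrame columns
mnist = fetch_openml('mnist_784', as_frame=False)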
y = mnist.target
X = mnist.data

fig, ax = plt.subplots(2, 5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]  # first row whose label equals i
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))  # 784 pixels -> 28x28 image
    ax[i].imshow(plottable_image, cmap='gray_r')
plt.show()

Answer:

The error occurs because np.argwhere(y == i) returns an empty array, so taking [0] from it raises an IndexError. The array is empty because y holds the labels as strings ('0', '1', ..., '9') while i is an int, so the comparison never matches.
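A small reproduction of the mismatch, using a hypothetical five-element object array as a stand-in for mnist.target (the real array has 70,000 string labels). It shows that comparing strings with an int matches nothing, which is what leaves np.argwhere empty.

```python
import numpy as np

# Hypothetical stand-in for mnist.target: the labels are strings, not ints
y = np.array(['5', '0', '4', '1', '9'], dtype=object)

mask = y == 0                  # '0' == 0 is False, so every entry is False
matches = np.argwhere(mask)    # shape (0, 1): no matching rows
print(mask, matches.shape)

try:
    im_idx = matches[0]        # same indexing as in the question
except IndexError as err:
    print("IndexError:", err)
```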

The following change will fix it:

im_idx = np.argwhere(y == str(i))[0]
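The fix above converts i to a string on every comparison. An alternative sketch converts the labels to integers once instead. Since the question is about random samples, it also picks a random row per digit rather than the first one. It assumes y is the label array from fetch_openml (a NumPy array with as_frame=False, or a pandas Series, which np.asarray converts). random_index_per_digit is a hypothetical helper name, not part of scikit-learn.

```python
import numpy as np


def random_index_per_digit(y, seed=None):
    """Return {digit: row index} with one random row for each digit 0-9.

    Hypothetical helper: y holds string labels such as '0'..'9', as returned
    by fetch_openml('mnist_784'); they are converted to ints once up front.
    """
    labels = np.asarray(y).astype(np.int64)  # '7' -> 7
    rng = np.random.default_rng(seed)
    picks = {}
    for digit in range(10):
        candidates = np.flatnonzero(labels == digit)  # every row with this label
        if candidates.size == 0:
            raise ValueError(f"no samples with label {digit}")
        picks[digit] = int(rng.choice(candidates))   # one random row among them
    return picks
```

Each returned index can stand in for im_idx in the question's loop, e.g. np.reshape(X[picks[i]], (28, 28)), provided X is a NumPy array.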