I am trying to plot one sample of each digit from the MNIST dataset, but the code raises an error. What am I doing wrong?
import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')
y = mnist.target
X = mnist.data
fig, ax = plt.subplots(2, 5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax[i].imshow(plottable_image, cmap='gray_r')
CodePudding user response:
The error occurs because np.argwhere(y == i) returns an empty array, so indexing it with [0] fails. The array is empty because y holds string labels while i is an int, so the comparison never matches.
The following change will fix it:
im_idx = np.argwhere(y == str(i))[0]
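To see the mismatch without downloading the dataset, here is a minimal sketch. It assumes a small hand-made array of string labels as a stand-in for mnist.target, and first_index_of is a hypothetical helper, not part of NumPy or scikit-learn.

```python
import numpy as np

# Stand-in for mnist.target (assumption): labels stored as strings.
# dtype=object makes == compare element by element using Python equality.
y = np.array(['5', '0', '4', '1', '9', '2', '1', '3'], dtype=object)


def first_index_of(labels, digit):
    """Return the index of the first label equal to digit (hypothetical helper)."""
    # Convert the int to a string so it is compared against string labels.
    matches = np.argwhere(labels == str(digit))
    if matches.size == 0:
        raise ValueError(f"no label equal to {digit!r}")
    # argwhere on a 1-D array returns shape (n, 1), so take row 0, column 0.
    return int(matches[0, 0])


# Comparing string labels with an int matches nothing, so [0] would fail.
int_matches = np.argwhere(y == 1)
print(int_matches.size)

# Comparing with the string form finds the label.
print(first_index_of(y, 1))
```

Another option, if string labels are not needed elsewhere, is to convert the labels once with y = y.astype(int) and keep the original comparison y == i.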