I am sampling images from the MNIST dataset, but I get an error. Where am I going wrong?
import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')
y = mnist.target
X = mnist.data.astype('float64')
fig, ax = plt.subplots(2, 5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == str(i))[0]  # pandas version bug? This worked in PyCharm. Is there another way to fix it?
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax[i].imshow(plottable_image, cmap='gray_r')
The error is:
ValueError: Length of passed values is 1, index implies 70000.
Answer:
The error you are getting is because np.argwhere(y == i) returns an empty array: y is filled with string values, while i is an int, so the comparison never matches.
The following change will fix it:
im_idx = np.argwhere(y == str(i))[0]
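To see why the label type matters, here is a minimal sketch on a small hand-made pandas Series of string labels. The Series is a hypothetical stand-in for mnist.target, assumed (as the answer states) to hold strings such as '0' to '9'. Comparing with an int matches nothing, while comparing with the string form of the digit finds the rows.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for mnist.target: digit labels stored as strings
y = pd.Series(['5', '0', '4', '1', '9', '2', '1'])

# Comparing strings with an int is never true, so no positions are found
int_matches = np.flatnonzero((y == 1).to_numpy())

# Comparing with the string form of the digit finds the matching positions
str_matches = np.flatnonzero((y == str(1)).to_numpy())

print(int_matches)  # []
print(str_matches)  # [3 6]
```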
Edit:
Convert the pandas objects returned by fetch_openml to NumPy arrays, so that np.argwhere works on a plain array and X[im_idx] selects a row by position:
y = mnist.target.to_numpy()
X = mnist.data.astype('float64').to_numpy()
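Putting the answer together, here is a sketch of the plotting loop written as a function that expects plain NumPy arrays, such as the converted X and y above. The function name plot_first_of_each_digit and the synthetic demo data are hypothetical, used only so the sketch runs without downloading MNIST. As an alternative to calling .to_numpy(), fetch_openml also accepts as_frame=False, which should return NumPy arrays directly.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_first_of_each_digit(X, y):
    """Plot the first image of each digit 0-9 on a 2x5 grid.

    X: NumPy array of shape (n_samples, 784); y: NumPy array of string labels.
    Returns the figure and the list of row indices that were plotted.
    """
    fig, ax = plt.subplots(2, 5)
    ax = ax.flatten()
    plotted = []
    for i in range(10):
        # Row positions whose label equals the digit as a string
        matches = np.flatnonzero(y == str(i))
        if matches.size == 0:
            ax[i].axis("off")  # no sample of this digit; leave the panel empty
            continue
        im_idx = matches[0]
        plotted.append(int(im_idx))
        # X is a plain array, so X[im_idx] selects a row (a DataFrame would select a column)
        ax[i].imshow(X[im_idx].reshape(28, 28), cmap="gray_r")
    return fig, plotted

# Synthetic stand-in for MNIST: 20 random "images" with string labels
rng = np.random.default_rng(0)
X_demo = rng.random((20, 784))
y_demo = np.array([str(d) for d in [5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9]])
fig, plotted = plot_first_of_each_digit(X_demo, y_demo)
print(plotted)  # [1, 3, 5, 7, 2, 0, 13, 15, 17, 4]
```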