The code below opens the MNIST dataset from a CSV file and tries to display one digit:
import numpy as np
import csv
import matplotlib.pyplot as plt

with open('C:/Z_Uni/Individual_Project/Python_Projects/NeuralNet/MNIST_Dataset/mnist_train.csv/mnist_train.csv', 'r') as csv_file:
    for data in csv.reader(csv_file):
        # The first column is the label
        label = data[0]
        # The rest of columns are pixels
        pixels = data[1:]
        # Make those columns into a array of 8-bits pixels
        # This array will be of 1D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype='uint8')
        print(pixels.shape)
        # Reshape the array into 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))
        print(pixels.shape)
        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()
        break  # This stops the loop, I just want to see one
I got the code above from someone and cannot get it to display the MNIST digits.
I get the error:
Traceback (most recent call last):
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\Test_View_Mnist.py", line 16, in <module>
    pixels = np.array(pixels, dtype='uint8')
ValueError: invalid literal for int() with base 10: '1x1'
When I remove dtype='uint8' I get the error:
Traceback (most recent call last):
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\Test_View_Mnist.py", line 24, in <module>
    plt.imshow(pixels, cmap='gray')
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\_api\deprecation.py", line 456, in wrapper
    return func(*args, **kwargs)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\pyplot.py", line 2640, in imshow
    __ret = gca().imshow(
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\_api\deprecation.py", line 456, in wrapper
    return func(*args, **kwargs)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\__init__.py", line 1412, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\axes\_axes.py", line 5488, in imshow
    im.set_data(X)
  File "C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\image.py", line 706, in set_data
    raise TypeError("Image data of dtype {} cannot be converted to "
TypeError: Image data of dtype <U5 cannot be converted to float
Process finished with exit code 1
Could someone explain why this error is happening and how to fix it? Thanks.
CodePudding user response:
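There are two problems here. (1) You need to skip the first row, because it is a header containing column names (1x1, 1x2, and so on) rather than pixel values; that is where the '1x1' in your ValueError comes from. (2) The pixels must be converted to a numeric data type, because imshow cannot plot an array of strings. The code below uses int64; once the header is skipped, uint8 should also work, since the pixel values are between 0 and 255. Calling next(csvreader) before the loop skips the first row.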
import numpy as np
import csv
import matplotlib.pyplot as plt

with open('./mnist_test.csv', 'r') as csv_file:
    csvreader = csv.reader(csv_file)
    # Skip the header row (column names such as 1x1, 1x2, ...)
    next(csvreader)
    for data in csvreader:
        # The first column is the label
        label = data[0]
        # The rest of the columns are pixels
        pixels = data[1:]
        # Make those columns into an array of integer pixels
        # This array will be 1D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype='int64')
        print(pixels.shape)
        # Reshape the array into a 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))
        print(pixels.shape)
        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()
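
To see why skipping the header matters, here is a minimal sketch that uses only the standard library. It assumes the file starts with a header row like label,1x1,1x2,... (which is what the '1x1' in the error message suggests). The sample data and the helper name read_digits are made up for illustration, and a 2 x 2 "image" stands in for the real 784 pixel columns.

```python
import csv
import io

# Hypothetical miniature file: a header row followed by two data rows.
# Real MNIST rows have 784 pixel columns; 4 are used here to keep it short.
SAMPLE = "label,1x1,1x2,2x1,2x2\n5,0,255,128,0\n0,12,0,0,200\n"


def read_digits(csv_file, has_header=True):
    """Yield (label, pixels) pairs, with pixels as a flat list of ints."""
    reader = csv.reader(csv_file)
    if has_header:
        # Skip the column names; they are text, not numbers,
        # so converting them with int() raises ValueError.
        next(reader, None)
    for row in reader:
        yield int(row[0]), [int(value) for value in row[1:]]


digits = list(read_digits(io.StringIO(SAMPLE)))
print(digits[0])  # (5, [0, 255, 128, 0])

# Without skipping the header, the very first row fails to convert.
try:
    list(read_digits(io.StringIO(SAMPLE), has_header=False))
except ValueError as err:
    print("header row rejected:", err)
```

With the real file, the only change needed in the original script is the same next(...) call on the reader before the loop starts.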