Using ImageDataGenerator, matplotlib throws TypeError: Invalid shape (1, 256, 256, 3) for image data

Time:10-11

I have 15 images of cars and want to build a dataset from them using data augmentation. However, when I use ImageDataGenerator from Keras and try to plot the generated images, I get this error:

TypeError: Invalid shape (1, 256, 256, 3) for image data.

Here is my code. Please let me know how I can fix this.

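import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1./255, zoom_range=0.1, rotation_range=25,
                             width_shift_range=0.1, height_shift_range=0.1,
                             shear_range=0.1, horizontal_flip=True)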

ite = datagen.flow_from_directory("Car Images", batch_size=1)

for i in range(9):

    # define subplot
    plt.subplot(330 + 1 + i)

    # generate batch of images
    batch = ite.next()

    # convert to unsigned integers for viewing
    image = batch[0].astype('uint8')

    # plot raw pixel data
    plt.imshow(image)

# show the figure
plt.show()

The error points to the plt.imshow() line.

Using np.squeeze() or np.reshape() does not give the expected result either.

User response:

You need to reshape the image to remove the batch dimension. Try this:

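import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1./255, zoom_range=0.1, rotation_range=25,
                             width_shift_range=0.1, height_shift_range=0.1,
                             shear_range=0.1, horizontal_flip=True)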

ite = datagen.flow_from_directory("Car Images", batch_size=1)

for i in range(9):

    # define subplot
    plt.subplot(330 + 1 + i)

    # generate batch of images
    batch = ite.next()

    # convert to unsigned integers for viewing
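    # undo rescale=1./255 first, otherwise almost every pixel truncates to 0
    image = (batch[0] * 255).astype('uint8')

    # drop the batch dimension: (1, 256, 256, 3) -> (256, 256, 3)
    image = np.reshape(image, (256, 256, 3))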

    # plot raw pixel data
    plt.imshow(image)

# show the figure
plt.show()
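Note that np.reshape takes the array first and the target shape second, so a call like np.reshape(256,256,3) cannot work. A minimal check of the shape logic, assuming batch_size=1 and the default target_size of (256, 256); a random NumPy array stands in for a real batch from flow_from_directory:

```python
import numpy as np

# Stand-in for batch[0] from flow_from_directory(..., batch_size=1):
# one 256x256 RGB image with a leading batch axis (random values are an
# assumption; real batches come from ImageDataGenerator).
rng = np.random.default_rng(0)
batch_x = rng.random((1, 256, 256, 3), dtype=np.float32)

# np.reshape takes the array first, then the target shape as a tuple.
image = np.reshape(batch_x, (256, 256, 3))

# Same result without hard-coding height/width: keep every axis but the batch axis.
image_generic = batch_x.reshape(batch_x.shape[1:])

print(image.shape)                            # (256, 256, 3)
print(np.array_equal(image, image_generic))   # True
```

The second form keeps working if you later change target_size.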

User response:

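The Invalid shape error occurs because the array still has its batch dimension: with batch_size=1, batch[0] has shape (1, 256, 256, 3), while imshow expects (256, 256, 3). Remove that dimension with np.squeeze() or np.reshape().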
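A small sketch of that idea as a helper. The name imshow_ready is hypothetical, a zero-filled array stands in for a real batch, and the accepted shapes are the ones the imshow documentation lists: (M, N), (M, N, 3) and (M, N, 4).

```python
import numpy as np


def imshow_ready(batch_x):
    """Hypothetical helper: turn a (1, H, W, C) batch into an array with one
    of the shapes imshow documents: (H, W), (H, W, 3) or (H, W, 4)."""
    if batch_x.ndim == 4 and batch_x.shape[0] == 1:
        # Remove only the batch axis; a bare np.squeeze() would also drop
        # any other axis that happens to have size 1.
        batch_x = np.squeeze(batch_x, axis=0)
    valid = batch_x.ndim == 2 or (batch_x.ndim == 3 and batch_x.shape[-1] in (3, 4))
    if not valid:
        raise ValueError(f"shape {batch_x.shape} cannot be shown with imshow")
    return batch_x


batch_x = np.zeros((1, 256, 256, 3), dtype=np.float32)
print(imshow_ready(batch_x).shape)   # (256, 256, 3)

try:
    imshow_ready(np.zeros((2, 256, 256, 3)))   # batch_size=2 needs indexing, not squeezing
except ValueError as err:
    print(err)   # shape (2, 256, 256, 3) cannot be shown with imshow
```

With a batch_size larger than 1, index a single image instead (for example batch[0][i]).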

Furthermore, since you rescale your images by 1./255, the pixel values are in the range [0, 1], and converting them to uint8 truncates almost all of them to zero (only values of exactly 1.0 become 1). So change the last two lines in your for loop like this:

image = batch[0]                  # no astype('uint8')
# plot raw pixel data
plt.imshow(np.squeeze(image))     # remove the batch dimension
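To see why astype('uint8') blanks the images, here is a tiny sketch; a hand-made array stands in for real pixel values after rescale=1./255 (an assumption for illustration):

```python
import numpy as np

# Pixel values after ImageDataGenerator(rescale=1./255): floats in [0, 1].
scaled = np.array([0.0, 0.25, 0.5, 0.999, 1.0], dtype=np.float32)

# astype('uint8') truncates toward zero, so everything below 1.0 becomes 0.
print(scaled.astype('uint8'))   # [0 0 0 0 1]

# Option 1: keep the floats; imshow accepts RGB floats in [0, 1].
# Option 2: undo the rescale before converting, giving uint8 values in [0, 255].
restored = (scaled * 255).astype('uint8')
print(restored)                 # [  0  63 127 254 255]
```

Either option displays the image correctly; Option 1 is simpler because it leaves the generator's output untouched.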