Proper way of resizing image for Deep Learning models

Time:04-12

I'm a beginner in Deep Learning & TensorFlow. During preprocessing, I keep getting stuck on the step where I have to resize images to the specific dimensions required by a particular NN architecture. I googled and tried different methods, but in vain.

For example, I did the following to resize the images to 227 x 227 for AlexNet:

import cv2
import numpy as np
import tensorflow as tf

height = 227
width = 227
dim = (width, height)  # cv2.resize expects (width, height)

x_train = np.array([cv2.resize(img, dim) for img in x_train[:,:,:]])
x_valid = np.array([cv2.resize(img, dim) for img in x_valid[:,:,:]])

x_train = tf.expand_dims(x_train, axis=-1)
x_valid = tf.expand_dims(x_valid, axis=-1)

I resize the images with cv2, but after tf.expand_dims each image has the shape:

(227, 227, 1)

whereas I want them to be:

(227, 227, 3)

So, is there any better way to do this?

Answer:

The following line in your script is causing the problem:

x_train = np.array([cv2.resize(img, dim) for img in x_train[:,:,:]])

Change it to:

x_train = np.array([cv2.resize(img, dim) for img in x_train])
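A minimal sketch, assuming x_train holds single-channel grayscale images (for example MNIST, shape (N, 28, 28)). For a NumPy array, iterating over x_train[:,:,:] yields the same images as iterating over x_train, so the channel count after cv2.resize comes from the input data itself: cv2.resize keeps a 2-D image 2-D, and tf.expand_dims can then only add a channel of size 1. One way to get three channels is to repeat that single channel. The helper to_three_channels and the random batch x_resized below are hypothetical stand-ins.

```python
import numpy as np

def to_three_channels(images: np.ndarray) -> np.ndarray:
    """Hypothetical helper: replicate one grayscale channel into three.

    Accepts a batch shaped (N, H, W) or (N, H, W, 1) and returns (N, H, W, 3).
    """
    if images.ndim == 3:
        # (N, H, W) -> (N, H, W, 1): add an explicit channel axis
        images = images[..., np.newaxis]
    if images.shape[-1] != 1:
        raise ValueError(f"expected a single channel, got shape {images.shape}")
    # Copy the single channel three times along the last axis
    return np.repeat(images, 3, axis=-1)

# Hypothetical stand-in for a grayscale batch already resized with cv2.resize
x_resized = np.random.randint(0, 256, size=(4, 227, 227), dtype=np.uint8)

# [:, :, :] on a 3-D array selects the whole array, so both loops see the same images
same = all(np.array_equal(a, b) for a, b in zip(x_resized, x_resized[:, :, :]))

x_rgb = to_three_channels(x_resized)
print(same, x_rgb.shape)  # True (4, 227, 227, 3)
```

Repeating the channel leaves the pixel values unchanged, so a network expecting RGB input sees three identical planes. With OpenCV, cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) gives the same (H, W, 3) result for a single 2-D image.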

Answer:

One faster option is to create a dataset with tf.data.Dataset and then map a function that resizes each image with tf.image.resize, like below:

import tensorflow as tf
import matplotlib.pyplot as plt

# CIFAR-10 images are already RGB: X_train has shape (50000, 32, 32, 3), dtype uint8
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))


HEIGHT = 227
WIDTH = 227

def resize_preprocess(image, label):
    # tf.image.resize (bilinear by default) returns float32; scale to [0, 1]
    image = tf.image.resize(image, (HEIGHT, WIDTH)) / 255.0
    return image, label


# Resize in parallel as elements are pulled through the pipeline
train_dataset = train_dataset.map(resize_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.map(resize_preprocess, num_parallel_calls=tf.data.AUTOTUNE)


# Inspect one resized example
for image, label in train_dataset.take(1):
    print(image.shape)
    plt.imshow(image.numpy())
    plt.axis('off')
    plt.show()

Output:

(227, 227, 3)
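If the dataset is grayscale rather than RGB like CIFAR-10, the same tf.data pipeline can add the channel axis, resize, and expand to three channels inside the mapped function. A minimal sketch under that assumption; the function name resize_gray_to_rgb and the random arrays x and y are hypothetical stand-ins for real data such as MNIST:

```python
import numpy as np
import tensorflow as tf

HEIGHT, WIDTH = 227, 227

def resize_gray_to_rgb(image, label):
    # image arrives as (H, W) uint8; add a channel axis -> (H, W, 1)
    image = tf.expand_dims(image, axis=-1)
    # Resize to the target size (float32 output) and scale to [0, 1]
    image = tf.image.resize(image, (HEIGHT, WIDTH)) / 255.0
    # Repeat the single channel into three -> (HEIGHT, WIDTH, 3)
    image = tf.image.grayscale_to_rgb(image)
    return image, label

# Hypothetical grayscale data standing in for e.g. MNIST, shape (N, 28, 28)
x = np.random.randint(0, 256, size=(8, 28, 28), dtype=np.uint8)
y = np.arange(8)

ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.map(resize_gray_to_rgb, num_parallel_calls=tf.data.AUTOTUNE).batch(4)

for images, labels in ds.take(1):
    print(images.shape)  # (4, 227, 227, 3)
```

tf.image.grayscale_to_rgb requires the last dimension to be 1, which is why the channel axis is added before resizing.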

