I am having trouble using a DataGenerator class when training my model (with model.fit_generator() in the code below). I am using the DataGenerator because I have a large number of images and labels for my object detection task.
My images are split into training, test and validation sets. I have converted the images and labels to NumPy arrays and preprocessed them for my ResNet50 model, which works fine:
val_images = np.array(val_images)
train_images = np.array(train_images)
However, when I wrap the training and validation arrays in the DataGenerator class, it does not seem to work:
training_generator = DataGenerator(train_images, train_targets)
validation_generator = DataGenerator(val_images, val_targets)
When I then try to train with these generators, it fails with an error:
resnet_model = model.fit_generator(
training_generator,
epochs=4,
validation_data=validation_generator)
TypeError: expected str, bytes or os.PathLike object, not ndarray
Full traceback: https://www.toptal.com/developers/hastebin/gusicucali.yaml
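This TypeError is raised when a function that expects a file path (for example open() or an image-loading call) receives a NumPy array instead. The DataGenerator class itself is not shown, so this is an assumption: it probably follows the common pattern of taking a list of file names or IDs and loading each image from disk inside __getitem__. If the data is already in memory as arrays, a generator can simply slice them. The sketch below uses a hypothetical ArrayBatchGenerator that mirrors the interface of tf.keras.utils.Sequence (__len__, __getitem__, on_epoch_end) in plain NumPy; in real training code you would subclass tf.keras.utils.Sequence (calling super().__init__() in the constructor) so that model.fit() accepts it.

```python
import math

import numpy as np


class ArrayBatchGenerator:
    """Serves (images, targets) batches from in-memory NumPy arrays.

    Hypothetical class mirroring the tf.keras.utils.Sequence interface.
    Subclass tf.keras.utils.Sequence in real code so Keras can consume it.
    """

    def __init__(self, images, targets, batch_size=32, shuffle=True, seed=0):
        if len(images) != len(targets):
            raise ValueError("images and targets must have the same length")
        self.images = np.asarray(images)
        self.targets = np.asarray(targets)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        # Take this batch's slice of the (possibly shuffled) index order,
        # then gather the matching images and targets so they stay aligned.
        start = idx * self.batch_size
        batch_idx = self.indices[start:start + self.batch_size]
        return self.images[batch_idx], self.targets[batch_idx]

    def on_epoch_end(self):
        # Reshuffle the sample order between epochs.
        if self.shuffle:
            self.rng.shuffle(self.indices)


# Usage with tiny dummy data: 10 "images" of shape (4, 4, 3) and 10 boxes.
gen = ArrayBatchGenerator(np.zeros((10, 4, 4, 3), np.float32),
                          np.zeros((10, 4), np.float32), batch_size=4)
x, y = gen[len(gen) - 1]
print(len(gen), x.shape, y.shape)
```

With a Sequence subclass built this way, training would look like model.fit(training_generator, validation_data=validation_generator, epochs=4).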
I am not sure whether this is the right way to use the DataGenerator class, as I am quite new to it. I have 6000 images, each with its bounding-box labels (xmin, ymin, ymax, xmax). I am trying to use a generator to make training more efficient, since I do not have a GPU.
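A quick size estimate shows what a generator actually buys here. Assuming (hypothetically) 224x224 RGB inputs, ResNet50's default size, stored as float32 (4 bytes per value), the arithmetic below compares holding all 6000 images in memory with holding a single batch of 32. Note that a generator mainly bounds memory use; it does not by itself make CPU training faster, because the same forward and backward passes are still computed.

```python
def array_nbytes(num_images, height, width, channels, bytes_per_value):
    """Bytes needed for a dense image array of the given shape and dtype size."""
    return num_images * height * width * channels * bytes_per_value


# Assumed input size: 224x224x3, float32 (4 bytes per value).
full = array_nbytes(6000, 224, 224, 3, 4)
batch = array_nbytes(32, 224, 224, 3, 4)
print(f"all 6000 images: {full / 1e9:.2f} GB")  # all 6000 images: 3.61 GB
print(f"one batch of 32: {batch / 1e6:.1f} MB")  # one batch of 32: 19.3 MB
```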
Answer:
I think you should aim for code similar to the one below. I used ImageDataGenerator; its documentation lists many arguments for data augmentation, and I copied only a few.
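One caveat for object detection: ImageDataGenerator transforms only the images, so any non-zero geometric augmentation (rotation, shift, zoom, flip) would move the objects while the bounding-box targets keep their old coordinates. With every such argument set to 0 or False, as in the code here, no geometric augmentation is applied. If you want augmentation, the boxes have to be transformed together with the image. A minimal NumPy sketch for a horizontal flip, using a hypothetical helper hflip_image_and_boxes and assuming boxes in pixel coordinates ordered (xmin, ymin, xmax, ymax), where x runs from 0 at the left edge to the image width at the right edge (for the (xmin, ymin, ymax, xmax) order in the question, the x columns are 0 and 3; for boxes normalized to [0, 1], use 1 - x):

```python
import numpy as np


def hflip_image_and_boxes(image, boxes):
    """Horizontally flip an (H, W, C) image and its (N, 4) boxes together.

    Assumed box layout: (xmin, ymin, xmax, ymax) in pixel-edge coordinates.
    A horizontal flip leaves the y coordinates unchanged.
    """
    width = image.shape[1]
    flipped_image = image[:, ::-1, :]
    flipped_boxes = boxes.astype(float).copy()
    # Mirroring x -> width - x turns the old right edge into the new left edge.
    flipped_boxes[:, 0] = width - boxes[:, 2]
    flipped_boxes[:, 2] = width - boxes[:, 0]
    return flipped_image, flipped_boxes


# A 4x10 image with an object covering columns 1..3 (x in [1, 4)).
img = np.zeros((4, 10, 3))
img[1:3, 1:4] = 1.0
boxes = np.array([[1, 1, 4, 3]])
fimg, fboxes = hflip_image_and_boxes(img, boxes)
print(fboxes.tolist())  # [[6.0, 1.0, 9.0, 3.0]]
```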
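import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

val_images = np.array(val_images)
train_images = np.array(train_images)
# 0 / False disables each augmentation; geometric transforms would not
# update the bounding-box targets.
dataGenerator = ImageDataGenerator(rotation_range=0,
                                   zoom_range=0,
                                   width_shift_range=0,
                                   height_shift_range=0,
                                   horizontal_flip=False,
                                   fill_mode="nearest")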
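# Train. fit_generator returns a History object, so don't assign it to `model`.
# steps_per_epoch counts batches, not images: ceil(N / BATCH_SIZE).
history = model.fit_generator(dataGenerator.flow(train_images, train_targets, batch_size=BATCH_SIZE),
                              validation_data=(val_images, val_targets),
                              steps_per_epoch=(len(train_images) + BATCH_SIZE - 1) // BATCH_SIZE,
                              epochs=EPOCHS)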
This should work. Also, fit_generator is deprecated in TensorFlow 2.x; model.fit accepts the same generator and arguments, so you can replace fit_generator with fit without changing anything else.
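On steps_per_epoch: it counts batches, not images, so one pass over the training set is ceil(N / BATCH_SIZE) steps. Passing len(train_images) instead makes each epoch BATCH_SIZE times longer than intended. A small sketch with hypothetical sizes (the question does not say how the 6000 images are split between training, validation and test):

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


# Hypothetical training-set sizes with a batch size of 32.
print(steps_per_epoch(4800, 32))  # 150
print(steps_per_epoch(4801, 32))  # 151: one extra, partial batch
```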