predict all test batches in keras / tensorflow

Time:04-25

I am trying to predict all of my test batches in Keras / TensorFlow so that I can plot a confusion matrix. The current BATCH_SIZE is 32.

My test dataset is split off from a larger dataset with the following code:

test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)
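A side note on this pipeline: tf.data.Dataset.shuffle has a reshuffle_each_iteration argument that defaults to True, so every new pass over test_dataset (for example, one pass to read the inputs and another to read the labels) can see the samples in a different order. The minimal sketch below uses a toy range dataset as a stand-in for the real data; passing reshuffle_each_iteration=False draws the order once and reuses it.

```python
import tensorflow as tf

# Toy stand-in for test_dataset: the integers 0..9 (hypothetical data).
ds = tf.data.Dataset.range(10)

# Default behaviour: reshuffle_each_iteration=True, so each pass may use a new order.
reshuffled = ds.shuffle(10)
first_pass = [int(x) for x in reshuffled]
second_pass = [int(x) for x in reshuffled]

# With reshuffle_each_iteration=False the order is drawn once and then reused.
fixed = ds.shuffle(10, seed=42, reshuffle_each_iteration=False)
fixed_first = [int(x) for x in fixed]
fixed_second = [int(x) for x in fixed]

print(first_pass, second_pass)      # usually two different orders
print(fixed_first == fixed_second)  # True
```

Either way every pass still contains each element exactly once; only the order changes.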

After model.compile() and model.fit(), I get the predictions and the true labels with this code:

import tensorflow as tf

points, labels = list(test_dataset)[0]  # only the first (x, y) batch
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)       # predicted class index per sample
points = points.numpy()

This only predicts a single batch, i.e. 32 predictions.

Is there a way to predict all of the test batches in Keras / TensorFlow?

Thanks in advance!

Answer:

According to the docs, you can pass the entire dataset to model.predict. The x argument is described as:

Input samples. It could be:
- A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
- A tf.data dataset.
- A generator or keras.utils.Sequence instance.
A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given in the Unpacking behavior for iterator-like inputs section of Model.fit.
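To illustrate that model.predict consumes every batch of a tf.data dataset, here is a small sketch with a hypothetical 100-sample dataset batched by 32 and a tiny untrained classifier (both stand-ins, not the asker's model). The result has one row per sample across all four batches, not just the first 32.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 32

# Hypothetical data: 100 samples with 4 features each, 3 classes.
x = np.random.rand(100, 4).astype("float32")
y = np.random.randint(0, 3, size=100)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(BATCH_SIZE)

# A tiny untrained classifier, used only to show the output shape.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# predict() iterates over all batches (32 + 32 + 32 + 4 samples)
# and concatenates the per-batch outputs into one array.
probs = model.predict(dataset.map(lambda xb, yb: xb), verbose=0)
print(probs.shape)  # (100, 3)
```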

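import tensorflow as tf

# Caution: test_dataset is shuffled and, by default, reshuffled on every pass,
# so points and labels below are read in two passes that may not line up.
points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)  # still a tf.data.Dataset of label batches
preds = model.predict(points)              # one row per test sample, all batches
preds = tf.math.argmax(preds, -1)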
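For a confusion matrix, each prediction has to stay paired with its own label. Because a shuffled dataset is by default reshuffled on every pass, one way to guarantee that is to read inputs and labels in a single loop. The sketch below assumes the dataset yields (features, labels) batches; predict_with_labels is a hypothetical helper name, and the toy data and model are stand-ins for the question's.

```python
import numpy as np
import tensorflow as tf


def predict_with_labels(model, dataset):
    """Hypothetical helper: run `model` over every (features, labels) batch
    of `dataset` in ONE pass and return aligned NumPy arrays
    (true_labels, predicted_classes)."""
    all_labels, all_preds = [], []
    for x_batch, y_batch in dataset:  # single pass keeps each label with its input
        probs = model(x_batch, training=False)
        all_preds.append(tf.math.argmax(probs, axis=-1).numpy())
        all_labels.append(y_batch.numpy())
    return np.concatenate(all_labels), np.concatenate(all_preds)


# Toy stand-ins for the question's test_dataset and model (hypothetical data).
x = np.random.rand(100, 4).astype("float32")
y = np.random.randint(0, 3, size=100)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(100)  # reshuffled on every pass, like the original pipeline
    .batch(32)
)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

labels, preds = predict_with_labels(model, test_dataset)
print(labels.shape, preds.shape)  # (100,) (100,)
```

The two returned arrays can be passed directly to tf.math.confusion_matrix(labels, preds). Calling the model directly inside the loop avoids starting a full model.predict run per batch; model.predict_on_batch(x_batch) is another option.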

Or, with NumPy arrays:

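import numpy as np
import tensorflow as tf

# Caution: each list(...) call is a separate pass over the (reshuffled) dataset.
points = np.concatenate(list(test_dataset.map(lambda x, y: x)))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y)))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)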
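Once the true labels and predicted classes are aligned 1-D integer arrays, the confusion matrix itself is just a table of (true, predicted) counts. A minimal NumPy-only sketch, with hypothetical labels and predictions for six samples; confusion_matrix_np is an illustrative name, and it uses the same layout as tf.math.confusion_matrix (rows are true classes, columns are predicted classes). A TensorFlow tensor such as the argmax output can be converted first with .numpy().

```python
import numpy as np


def confusion_matrix_np(labels, preds, num_classes):
    """Count (true, predicted) pairs. Rows are true classes and columns are
    predicted classes, the same layout tf.math.confusion_matrix uses."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at increments cm[label, pred] once per sample, even when the
    # same (label, pred) pair occurs several times.
    np.add.at(cm, (np.asarray(labels), np.asarray(preds)), 1)
    return cm


# Hypothetical labels/predictions for 6 samples and 3 classes.
labels = np.array([0, 1, 2, 2, 1, 0])
preds = np.array([0, 2, 2, 2, 1, 1])
cm = confusion_matrix_np(labels, preds, num_classes=3)
print(cm)
# [[1 1 0]
#  [0 1 1]
#  [0 0 2]]
```

The diagonal holds the correct predictions; off-diagonal cells show which classes get confused with which, and the matrix can then be plotted as a heatmap.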