I am trying to predict all of my test batches in Keras / TensorFlow so that I can then plot a confusion matrix.
The current BATCH_SIZE is 32.
My test dataset is split off a big dataset with the following code:
test_dataset = big_dataset.skip(train_size).take(test_size)
test_dataset = test_dataset.shuffle(test_size).map(augment).batch(BATCH_SIZE)
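Note that `shuffle` reshuffles on every pass by default (`reshuffle_each_iteration=True`), so each new iteration over `test_dataset` yields a different order. The toy sketch below uses `tf.data.Dataset.range(10)` as a stand-in for the real data (an assumption for illustration) and shows that a fixed `seed` plus `reshuffle_each_iteration=False` keeps the order identical across passes. This matters whenever the same test set is iterated more than once.

```python
import tensorflow as tf

# Toy stand-in for the real test data: the integers 0..9.
toy = tf.data.Dataset.range(10)

# Default behaviour: a new shuffle order on every pass over the dataset.
reshuffled = toy.shuffle(10)
print([int(v) for v in reshuffled])
print([int(v) for v in reshuffled])  # usually differs from the line above

# Fixed order: the same shuffle order is reused on every pass.
stable = toy.shuffle(10, seed=42, reshuffle_each_iteration=False)
first_pass = [int(v) for v in stable]
second_pass = [int(v) for v in stable]
print(first_pass == second_pass)  # True
```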
After model.compile() and model.fit(), I get the predictions and the correct labels with this code:
points, labels = list(test_dataset)[0]
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)
points = points.numpy()
This only predicts the first batch, i.e. 32 predictions.
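The reason is that `list(test_dataset)[0]` picks out only the first batch, and the rest of the test set sits in further batches. A toy sketch, assuming a hypothetical 100-sample test set in place of the real one:

```python
import tensorflow as tf

# Hypothetical 100-sample test set, batched like the question's pipeline.
BATCH_SIZE = 32
batched = tf.data.Dataset.range(100).batch(BATCH_SIZE)

first_batch = next(iter(batched))            # same batch as list(batched)[0]
num_batches = int(batched.cardinality())     # ceil(100 / 32) batches
batch_sizes = [int(tf.shape(b)[0]) for b in batched]
print(len(first_batch), num_batches, batch_sizes)  # 32 4 [32, 32, 32, 4]
```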
Is there a way to predict all test batches in Keras / TensorFlow?
Thanks in advance!
Answer:
You can pass the entire dataset to model.predict, which accepts a tf.data dataset according to the docs:
Input samples. It could be:
- A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
- A tf.data dataset.
- A generator or keras.utils.Sequence instance.
A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given in the Unpacking behavior for iterator-like inputs section of Model.fit.
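As a quick check of that behaviour, the sketch below feeds a batched dataset of `(features, label)` pairs straight into `model.predict`; Keras uses only the features and runs over every batch. The tiny model, the 8 input features and the 3 classes are made-up assumptions, not the question's model.

```python
import numpy as np
import tensorflow as tf

# Made-up data: 100 samples with 8 features and integer labels in {0, 1, 2}.
x = np.random.rand(100, 8).astype("float32")
y = np.random.randint(0, 3, size=100)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

# Hypothetical stand-in for the trained model: a 3-class softmax classifier.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# predict() iterates over all 4 batches; the labels in each tuple are ignored.
probs = model.predict(dataset, verbose=0)
print(probs.shape)  # (100, 3): one row per sample
```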
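# Each pass over test_dataset reshuffles it (shuffle defaults to
# reshuffle_each_iteration=True), so drop the shuffle or set that flag to
# False, otherwise points and labels come out in different orders.
points = test_dataset.map(lambda x, y: x)
labels = test_dataset.map(lambda x, y: y)  # still a dataset of label batches
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)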
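If the test set keeps its `shuffle` step, a safer pattern is to walk the dataset once and record each batch's labels and predictions together, so they cannot drift out of order. The sketch below assumes, as in the question, that `test_dataset` yields `(points, labels)` batches with integer class labels; the toy data and model are hypothetical and only make it runnable.

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins (assumptions): 70 samples, 8 features, 3 classes,
# shuffled and batched like the question's test set.
x = np.random.rand(70, 8).astype("float32")
y = np.random.randint(0, 3, size=70)
test_dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(70).batch(32)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

all_labels, all_preds = [], []
for points, labels in test_dataset:  # one pass, so every batch stays paired
    probs = model.predict_on_batch(points)
    all_preds.append(np.argmax(probs, axis=-1))
    all_labels.append(labels.numpy())

all_labels = np.concatenate(all_labels)
all_preds = np.concatenate(all_preds)
print(all_labels.shape, all_preds.shape)  # (70,) (70,)
```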
Or with NumPy:
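import numpy as np
import tensorflow as tf

# test_dataset is iterated twice here, so it must not reshuffle between passes.
points = np.concatenate(list(test_dataset.map(lambda x, y: x)))
labels = np.concatenate(list(test_dataset.map(lambda x, y: y)))
preds = model.predict(points)
preds = tf.math.argmax(preds, -1)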
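Once the labels and predicted class indices are aligned 1-D arrays, `tf.math.confusion_matrix` builds the matrix the question is after: rows are true classes and columns are predicted classes. The four values below are made up purely to show the layout.

```python
import tensorflow as tf

labels = [0, 1, 2, 1]  # made-up true classes
preds = [0, 2, 2, 1]   # made-up predicted classes

# Rows index the true label, columns index the predicted label.
cm = tf.math.confusion_matrix(labels, preds, num_classes=3)
print(cm.numpy())
# [[1 0 0]
#  [0 1 1]
#  [0 0 1]]
```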