I have a tf.data.Dataset that looks like this:
<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>
The second element (index 1 when zero-indexed) is the label. I want to cast the labels to tf.uint8.
How can one use tf.cast when dealing with a tf.data.Dataset?
Similar Questions
How to convert tf.int64 to tf.float32? is very similar, but it does not cover a tf.data.Dataset.
Repro
From Image classification from scratch:
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
unzip kagglecatsanddogs_5340.zip
Then in Python with tensorflow~=2.4:
import tensorflow as tf
ds = tf.keras.preprocessing.image_dataset_from_directory(
"PetImages", batch_size=32
)
print(ds)
CodePudding user response:
A map function may help:
import numpy as np
import tensorflow as tf
# build a small batched dataset of (features, labels) pairs
a = tf.data.Dataset.from_tensor_slices(np.empty((2,5,3)))
b = tf.data.Dataset.range(5, 8)
c = tf.data.Dataset.zip((a,b))
d = c.batch(1)
print(d)
# <BatchDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int64)>
# change the dtype of the 2nd arg in the batch from int64 to int8
e = d.map(lambda x,y:(x,tf.cast(y, tf.int8)))
print(e)
# <MapDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int8)>
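Applied to the dataset from the question, the same idea might look like the sketch below. It uses a small synthetic dataset with the same shapes and dtypes as a stand-in for PetImages, so it runs without the download. The names `cast_labels` and `ds_uint8` are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the question's dataset (an assumption, not PetImages):
# float32 images of shape (256, 256, 3) paired with int32 labels.
images = np.zeros((4, 256, 256, 3), dtype=np.float32)
labels = np.array([0, 1, 1, 0], dtype=np.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)


def cast_labels(image_batch, label_batch):
    """Leave the images untouched and cast only the labels to uint8."""
    return image_batch, tf.cast(label_batch, tf.uint8)


# map() applies the function to every (images, labels) batch in the dataset.
ds_uint8 = ds.map(cast_labels)
print(ds_uint8.element_spec)
```

Since tf.uint8 only holds values from 0 to 255, this suits datasets with at most 256 classes.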