How to tf.cast a field within a tensorflow Dataset

Time:05-27

I have a tf.data.Dataset that looks like this:

<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

The 2nd element (index 1 when zero-indexed) holds the labels. I want to cast those labels from tf.int32 to tf.uint8.

How can one use tf.cast when dealing with a tf.data.Dataset?


Similar Questions

How to convert tf.int64 to tf.float32? is very similar, but is not for a tf.data.Dataset.


Repro

From Image classification from scratch:

curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
unzip kagglecatsanddogs_5340.zip

Then in Python with tensorflow~=2.4:

import tensorflow as tf

ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages", batch_size=32
)
print(ds)

Answer:

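Dataset.map may help: map a function over the dataset that returns the first element unchanged and applies tf.cast to the second.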

import numpy as np
import tensorflow as tf

# Toy dataset with the same structure as the question: (features, integer labels)
a = tf.data.Dataset.from_tensor_slices(np.empty((2, 5, 3)))
b = tf.data.Dataset.range(5, 8)
c = tf.data.Dataset.zip((a, b))
d = c.batch(1)
print(d)
# <BatchDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int64)>

# Change the dtype of the 2nd element in each batch from int64 to int8
e = d.map(lambda x, y: (x, tf.cast(y, tf.int8)))
print(e)
# <MapDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int8)>
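Applied to the dataset in the question, the same idea might look like the sketch below. It does not load PetImages; instead it builds a small stand-in dataset with the same element types (float32 images, int32 labels), so the array contents and the helper name cast_labels are assumptions made for illustration only.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the image dataset (assumption): 4 float32 "images" with int32 labels.
images = np.zeros((4, 256, 256, 3), dtype=np.float32)
labels = np.array([0, 1, 1, 0], dtype=np.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)


def cast_labels(image, label):
    """Return the image batch unchanged and the labels cast to uint8."""
    return image, tf.cast(label, tf.uint8)


ds_uint8 = ds.map(cast_labels)

# element_spec describes each component's dtype without iterating the dataset.
image_spec, label_spec = ds_uint8.element_spec
print(image_spec.dtype, label_spec.dtype)
```

With the real dataset, replacing ds with the result of image_dataset_from_directory should work the same way, since map operates on each (images, labels) batch. Note that tf.uint8 only represents values from 0 to 255, so this is only appropriate when every label fits in that range.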