Filter dataset by label in TensorFlow


I'm new to TensorFlow (and Python in general) and I'm having a hard time wrapping my head around some features of tensors. I am using tf.keras.utils.image_dataset_from_directory() to get a dataset of images and labels (classes). I want to filter the images by class using filter(). Something like:

full_ds = tf.keras.utils.image_dataset_from_directory(
fibrosis_ds = full_ds.filter(lambda x, y:  y==0 ) # y == 0 for fibrosis

This gives the error:

ValueError: Invalid predicate. predicate must return a tf.bool scalar tensor, but its return type is NoneTensorSpec().

If I print y in the lambda the output is

Tensor("args_1:0", shape=(None,), dtype=int32)

And if I print in a loop

for x, y in full_ds:
    print(y)

the output is

tf.Tensor([1 1 1 1 0 1 1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 1 1], shape=(32,), dtype=int32)

This makes sense because the default batch size of image_dataset_from_directory() is 32. The 0s in this array represent Fibrosis and the 1s are a different class (Normals).

How do I get the lambda to work with filter()?

Answer:

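The problem seems to be that the filter is applied to batches: the predicate receives a whole batch of labels, so y == 0 has shape (None,) instead of being a scalar. Either unbatch, use tf.data.Dataset.filter, and re-batch: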

fibrosis_ds = full_ds.unbatch().filter(lambda x, y:  tf.equal(y, 0) ).batch(32) # y == 0 for fibrosis
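
A minimal runnable sketch of this approach, using a small synthetic dataset in place of the image directory. The `images` and `labels` arrays and the batch size of 4 are made up for illustration; they are not from the question.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for image_dataset_from_directory(): 8 tiny "images"
# with labels 0 (fibrosis) and 1 (normal), batched in groups of 4.
images = np.zeros((8, 2, 2, 3), dtype=np.float32)
labels = np.array([1, 0, 1, 1, 0, 1, 0, 1], dtype=np.int32)
full_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# After unbatch() each element is a single (image, label) pair, so the
# predicate returns a scalar tf.bool, which is what filter() requires.
fibrosis_ds = (
    full_ds.unbatch()
    .filter(lambda x, y: tf.equal(y, 0))
    .batch(4)
)

fibrosis_labels = [y.numpy().tolist() for _, y in fibrosis_ds]
print(fibrosis_labels)  # [[0, 0, 0]]
```

Re-batching after the filter gives uniform batch sizes again (except possibly the last batch).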

Or just use tf.data.Dataset.map (preferable):

fibrosis_ds = full_ds.map(lambda x, y:  (x[y==0], y[y==0]))
# or
fibrosis_ds = full_ds.map(lambda x, y:  (tf.boolean_mask(x, y==0), tf.boolean_mask(y, y==0)))
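
One thing to keep in mind with the map approach: each batch is masked independently, so the resulting batches have varying sizes and can be empty when a batch contains no label 0. A hedged sketch of how that could be handled, again on made-up synthetic data (the helper name `keep_fibrosis` is hypothetical):

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in data: the second batch of 4 contains no label 0.
images = np.zeros((8, 2, 2, 3), dtype=np.float32)
labels = np.array([1, 0, 1, 1, 1, 1, 1, 1], dtype=np.int32)
full_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

def keep_fibrosis(x, y):
    # Keep only the rows of the batch whose label is 0.
    mask = tf.equal(y, 0)
    return tf.boolean_mask(x, mask), tf.boolean_mask(y, mask)

masked_ds = full_ds.map(keep_fibrosis)

# Drop batches that ended up empty; this predicate is a scalar tf.bool,
# so filter() accepts it even though the elements are batches.
non_empty_ds = masked_ds.filter(lambda x, y: tf.shape(y)[0] > 0)

masked_sizes = [int(y.shape[0]) for _, y in masked_ds]
kept_sizes = [int(y.shape[0]) for _, y in non_empty_ds]
print(masked_sizes, kept_sizes)  # [1, 0] [1]
```

If uniform batch sizes matter downstream, appending .unbatch().batch(32) to the masked dataset is one option.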