filter dataset by label in tensorflow

Time:08-28

I'm new to TensorFlow (and Python in general) and I'm having a hard time wrapping my head around some features of tensors. I am using tf.keras.utils.image_dataset_from_directory() to get a dataset of images and labels (classes). I want to filter the images by class using filter(). Something like:

full_ds = tf.keras.utils.image_dataset_from_directory(
    'the_path',
    image_size=(SIZE,SIZE),
)
fibrosis_ds = full_ds.filter(lambda x, y:  y==0 ) # y == 0 for fibrosis

This gives the error

ValueError: Invalid predicate. predicate must return a tf.bool scalar tensor, but its return type is NoneTensorSpec().

If I print y in the lambda the output is

Tensor("args_1:0", shape=(None,), dtype=int32)

And if I print in a loop

for x, y in full_ds:
    print(y)
    break

the output is

tf.Tensor([1 1 1 1 0 1 1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 0 1 1 1 1 1], shape=(32,), dtype=int32)

Which makes sense because image_dataset_from_directory()'s default batch size is 32. The 0s in this array represent Fibrosis and the 1s are a different class (Normals).

How do I get the lambda to work with filter()?

CodePudding user response:

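The problem seems to be that the filter runs on batches: each y is a whole batch of labels (shape (None,)), so y == 0 is a vector rather than the scalar tf.bool that filter() expects. Either unbatch first and use tf.data.Dataset.filter: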

fibrosis_ds = full_ds.unbatch().filter(lambda x, y:  tf.equal(y, 0) ).batch(32) # y == 0 for fibrosis

Or just use tf.data.Dataset.map to mask each batch in place (preferable); note that the resulting batches vary in size and can be empty:

fibrosis_ds = full_ds.map(lambda x, y:  (x[y==0], y[y==0]))
# or
fibrosis_ds = full_ds.map(lambda x, y:  (tf.boolean_mask(x, y==0), tf.boolean_mask(y, y==0)))
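
To see how the two approaches differ, here is a minimal sketch on synthetic data. The arrays, the batch size of 4, and the variable names are assumptions made up for illustration; they stand in for the output of image_dataset_from_directory(). The last step, which drops empty batches, is an optional extra and not part of the original answer.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for image_dataset_from_directory():
# 10 tiny "images" with labels 0 (fibrosis) or 1 (normal), batched by 4.
images = np.zeros((10, 2, 2, 3), dtype=np.float32)
labels = np.array([1, 0, 1, 1, 1, 1, 1, 1, 0, 0], dtype=np.int32)
full_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Option 1: unbatch so the predicate sees a scalar label, filter, then re-batch.
filtered_unbatch = (
    full_ds.unbatch()
    .filter(lambda x, y: tf.equal(y, 0))
    .batch(4)
)

# Option 2: mask inside each batch. Batches shrink and may become empty.
def keep_fibrosis(x, y):
    mask = tf.equal(y, 0)
    return tf.boolean_mask(x, mask), tf.boolean_mask(y, mask)

filtered_map = full_ds.map(keep_fibrosis)

# Optional: drop batches that ended up with no matching samples.
filtered_nonempty = filtered_map.filter(lambda x, y: tf.shape(y)[0] > 0)

unbatch_sizes = [int(y.shape[0]) for _, y in filtered_unbatch]
map_sizes = [int(y.shape[0]) for _, y in filtered_map]
nonempty_sizes = [int(y.shape[0]) for _, y in filtered_nonempty]

print(unbatch_sizes)   # [3]
print(map_sizes)       # [1, 0, 2]
print(nonempty_sizes)  # [1, 2]
```

With unbatch() the matching samples are regrouped into full-size batches, while map() keeps the original batch boundaries, so the second batch here comes out empty. Which one is preferable depends on whether downstream code can handle variable-size batches.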