How to get the correct cardinality of a Tensorflow dataset after filtering

Time:02-05

I create a TensorFlow dataset with elements from 0 to 49, then filter it to keep only the elements less than 25, as follows:

import tensorflow as tf
dataset = tf.data.Dataset.range(50) 
dataset_less_25 = dataset.filter(lambda x: x < 25)

However, when I check the cardinality of the new dataset as follows:

dataset_less_25.cardinality().numpy() 

It returns -2, which does not make sense to me. I also checked that the new dataset does contain 25 elements, so why doesn't the cardinality() function work in this case?

Answer:

According to the documentation of this method, there are special integer codes for infinite and unknown cardinality. At the bottom of the page, we see that -2 means unknown cardinality (tf.data.UNKNOWN_CARDINALITY), while -1 means infinite cardinality (tf.data.INFINITE_CARDINALITY). In other words, the method was not able to infer the dataset size. In fact, the docs use filter as their example of a dataset with unknown cardinality.
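A minimal sketch, assuming TensorFlow 2.3 or later (where Dataset.cardinality() and the named constants are available), that compares the result against the named constants instead of raw integers, and contrasts filter with transformations whose output size follows directly from the input size:

```python
import tensorflow as tf

# Same filtered dataset as in the question.
dataset = tf.data.Dataset.range(50)
dataset_less_25 = dataset.filter(lambda x: x < 25)

card = dataset_less_25.cardinality()

# -2 is the value of the named constant tf.data.UNKNOWN_CARDINALITY.
print(card.numpy())                                   # -2
print((card == tf.data.UNKNOWN_CARDINALITY).numpy())  # True

# -1 is tf.data.INFINITE_CARDINALITY, e.g. for a dataset repeated forever.
inf_card = dataset.repeat().cardinality()
print((inf_card == tf.data.INFINITE_CARDINALITY).numpy())  # True

# map and batch derive their size from the input size alone,
# so their cardinality stays known without evaluating any element.
print(dataset.map(lambda x: x * 2).cardinality().numpy())  # 50
print(dataset.batch(10).cardinality().numpy())             # 5
```

Comparing against the constants keeps the check readable and avoids relying on the magic numbers -1 and -2.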

I'm not sure why this is the case. Digging into the code, the implementation of cardinality() is here; it calls gen_dataset_ops.dataset_cardinality. However, I could not find gen_dataset_ops in the codebase. This is because TensorFlow's gen_*_ops Python modules are generated at build time from the C++ op definitions, so the logic itself lives on the C++ side.

I would assume that this method only performs a very rudimentary static analysis, without evaluating any of the dataset's elements. For a range dataset it is easy to say how many elements there are, but for a filtered dataset there is no way to know which elements will pass the predicate without actually looking at them. When this simple analysis cannot succeed, the method returns "unknown".
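Since cardinality() cannot infer the size of a filtered dataset, one workaround (not part of the original answer) is to count the elements by iterating once, then optionally attach that count with tf.data.experimental.assert_cardinality so that later calls to cardinality() and len() report it. A minimal sketch, assuming TensorFlow 2.3+ running in eager mode:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(50)
dataset_less_25 = dataset.filter(lambda x: x < 25)

# Count elements with a single pass; reduce() keeps only a running
# counter instead of materializing the elements in a Python list.
num_elements = dataset_less_25.reduce(
    tf.constant(0, dtype=tf.int64),  # initial count
    lambda count, _: count + 1,      # add one per element, ignore its value
).numpy()
print(num_elements)  # 25

# Attach the known size to the dataset. If the real number of elements
# differs, iterating over dataset_25 raises an error.
dataset_25 = dataset_less_25.apply(
    tf.data.experimental.assert_cardinality(num_elements)
)
print(dataset_25.cardinality().numpy())  # 25
print(len(dataset_25))                   # 25
```

Counting costs a full pass over the data, which can be expensive for large input pipelines. If the size is already known in advance, passing it directly to assert_cardinality avoids that extra pass.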
