I have a dataset of images loaded with image_dataset_from_directory. It has 4 classes and yields (image, label) tuples.
I then created an 'id' column with the same length as the dataset, converted it to a tf.data dataset, and combined the two datasets using:
dataset = tf.data.Dataset.zip((dataset, client_id))
This results in a dataset with the following element spec:
<ZipDataset element_spec=((TensorSpec(shape=(128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int32, name=None)), TensorSpec(shape=(), dtype=tf.int64, name=None))>
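A minimal sketch that reproduces this element structure without any image files. It assumes small synthetic tensors in place of the output of image_dataset_from_directory; `num_samples` and the id values 10..17 are made up for illustration:

```python
import tensorflow as tf

num_samples = 8  # hypothetical dataset size

# Stand-in for the (image, label) dataset: random 128x128 RGB images
# and int32 labels in [0, 4), i.e. 4 classes.
images = tf.random.uniform((num_samples, 128, 128, 3), dtype=tf.float32)
labels = tf.random.uniform((num_samples,), maxval=4, dtype=tf.int32)
image_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# One int64 client id per element (hypothetical ids 10..17).
client_id = tf.data.Dataset.from_tensor_slices(
    tf.range(10, 10 + num_samples, dtype=tf.int64))

# Zip the two datasets: each element becomes ((image, label), client_id).
dataset = tf.data.Dataset.zip((image_ds, client_id))
print(dataset.element_spec)
```

The printed spec has the same nesting as above: an inner (image, label) pair plus a scalar int64 id.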
What I would like to do now is filter this combined dataset on the client id value whenever I want. What I tried is:
dataset = dataset.filter(lambda x: x[1] == 15)
but i get :
TypeError: 'ZipDataset' object is not subscriptable
However, this:
for x in dataset.take(1):
    print(x[1])
prints the client id correctly:
tf.Tensor(15, shape=(), dtype=int64)
How could this be done?
Answer:
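When iterating over your dataset, x[0] is the (image, label) tuple and x[1] holds the client id. Inside filter, however, each top-level tuple element is passed to the predicate as a separate argument, so the predicate should take two arguments. Maybe try: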
import tensorflow as tf
a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
dataset = tf.data.Dataset.zip((tf.data.Dataset.zip((a, b)), tf.data.Dataset.range(4, 7)))
dataset = dataset.filter(lambda x, y: y==4)
for x, y in dataset:
    print(x, y)
(<tf.Tensor: shape=(), dtype=int64, numpy=1>, <tf.Tensor: shape=(), dtype=int64, numpy=4>) tf.Tensor(4, shape=(), dtype=int64)
Note that y refers to the client ids in this case.
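Applied to the structure from the question, a sketch might look like this. The helper name `filter_by_client` is hypothetical, and the zero images, labels and ids are synthetic stand-ins for the real data. The predicate receives `image_label` (the inner (image, label) tuple, kept as one argument) and `client_id` (a scalar int64). Accepting a list of ids lets the same helper select one client or several, and the final `map` drops the id so the result yields plain (image, label) elements again:

```python
import tensorflow as tf

# Synthetic stand-in: 6 images belonging to clients 15, 15, 7, 15, 3, 7.
images = tf.zeros((6, 128, 128, 3), dtype=tf.float32)
labels = tf.constant([0, 1, 2, 3, 0, 1], dtype=tf.int32)
ids = tf.constant([15, 15, 7, 15, 3, 7], dtype=tf.int64)
dataset = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices((images, labels)),
    tf.data.Dataset.from_tensor_slices(ids),
))

def filter_by_client(ds, wanted_ids):
    """Keep elements whose client id is in wanted_ids, then drop the id."""
    # Match the int64 dtype of the client ids.
    wanted = tf.constant(wanted_ids, dtype=tf.int64)

    def predicate(image_label, client_id):
        # filter unpacks ((image, label), client_id) into two arguments;
        # return a scalar bool: is client_id one of the wanted ids?
        return tf.reduce_any(tf.equal(client_id, wanted))

    # Drop the id so elements are plain (image, label) tuples again.
    return ds.filter(predicate).map(lambda image_label, client_id: image_label)

client_15 = filter_by_client(dataset, [15])
for image, label in client_15:
    print(label.numpy())
```

For a single id, `dataset.filter(lambda xy, cid: cid == 15)` is enough on its own; the list form is only a convenience when several clients should be kept together.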