How to filter tensorflow dataset by value of single feature?
I spent a lot of time trying to understand how to filter TensorFlow datasets with the filter method; unfortunately, the documentation (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter) is not clear enough for me. Maybe this will be useful for someone.
In the example below, the goal is to select samples where the feature 'Status' equals 'success' and the feature 'Cost' is greater than 0.
dataset = tf.data.experimental.make_csv_dataset('file1.csv',....)
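# The (x, y) signature works because label_name is set, so each element is a (features, label) pair.
# The predicate only has to return a scalar tf.bool tensor; "True if ... else False" is unnecessary.
dataset = dataset.unbatch().filter(lambda x, y: x["Status"] == 'success')
dataset = dataset.filter(lambda x, y: x["Cost"] > 0.0)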
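A minimal, self-contained sketch of the same two-step filter, assuming the CSV rows are replaced by a small hypothetical in-memory dataset built with tf.data.Dataset.from_tensor_slices (so neither the file nor the elided make_csv_dataset arguments are needed). Each element is a (features, label) pair, which is why both predicates take two arguments:

```python
import tensorflow as tf

# Hypothetical rows standing in for 'file1.csv': a dict of feature columns
# plus a separate label column.
features = {
    "Status": ["success", "failure", "success", "success"],
    "Cost": [0.0, 3.5, 1.25, 2.0],
}
labels = [0, 1, 0, 1]

# Each element is a (features_dict, label) pair of scalar tensors, the same
# element structure make_csv_dataset(..., label_name=...) yields after unbatch().
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Each predicate returns a scalar tf.bool tensor; chaining filters ANDs them.
dataset = dataset.filter(lambda x, y: x["Status"] == "success")
dataset = dataset.filter(lambda x, y: x["Cost"] > 0.0)

for x, y in dataset:
    print(x["Status"].numpy(), x["Cost"].numpy(), y.numpy())
```

The first success row is dropped because its Cost is 0.0, and the failure row is dropped by the Status filter, so only the last two rows remain.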
Answer:
I hope this example will be useful.
Answer:
You can try something like this:
import tensorflow as tf
import pandas as pd
df = pd.DataFrame(data={'Status': ['Success', 'Failure','Failure', 'Success'], 'Cost': [0.0, 1.0, 1.0, 2.0]})
df.to_csv('data.csv', index=False)
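# Read back the same file written above (the original '/content/data.csv' only matches in Colab).
dataset = tf.data.experimental.make_csv_dataset('data.csv', batch_size=2, num_epochs=1)
# No label_name is passed, so each element is just the feature dict and the lambda takes one argument.
# tf.logical_and combines the two scalar bool tensors without relying on AutoGraph to convert Python's 'and'.
dataset = dataset.unbatch().filter(lambda x: tf.logical_and(x["Status"] == 'Success', x["Cost"] > 0.0))
for x in dataset:
    print(x['Status'], x['Cost'])
Output:
tf.Tensor(b'Success', shape=(), dtype=string) tf.Tensor(2.0, shape=(), dtype=float32)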
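A variant sketch, assuming hypothetical in-memory rows instead of the CSV. The question compares against 'success' while the answer's data uses 'Success', and == on a string tensor is an exact, case-sensitive comparison. Normalising the value with tf.strings.lower and combining both conditions with tf.logical_and in one predicate accepts either spelling:

```python
import tensorflow as tf

# Hypothetical rows with mixed casing in 'Status'.
dataset = tf.data.Dataset.from_tensor_slices({
    "Status": ["Success", "failure", "success", "SUCCESS"],
    "Cost": [0.0, 1.0, 3.0, 2.0],
})

def keep_row(x):
    # Lowercase the string tensor so 'Success', 'success' and 'SUCCESS' all match.
    is_success = tf.strings.lower(x["Status"]) == "success"
    # Both conditions must hold; the result is a scalar tf.bool tensor.
    return tf.logical_and(is_success, x["Cost"] > 0.0)

dataset = dataset.filter(keep_row)

for x in dataset:
    print(x["Status"].numpy(), x["Cost"].numpy())
```

The first row is dropped because its Cost is 0.0 and the second because of its Status; the original casing of the kept values is preserved, since only the comparison uses the lowercased copy.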