How to filter a TensorFlow dataset by the value of a single feature

Time:02-17

How to filter a TensorFlow dataset by the value of a single feature?

I spent a lot of time working out how to filter TensorFlow datasets with the filter method. Unfortunately, the documentation (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter) was not clear enough for me, so maybe this will be useful for someone.

In the example below, the goal is to select samples where the value of the feature 'Status' equals 'success' and the value of the feature 'Cost' is greater than 0.

dataset = tf.data.experimental.make_csv_dataset('file1.csv',....)
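# The two-argument lambda assumes each element is a (features, label) pair,
# e.g. when label_name is passed to make_csv_dataset.
# The comparison already yields a boolean tensor, so no "True if ... else False" is needed.
dataset = dataset.unbatch().filter(lambda x, y: x["Status"] == 'success')
dataset = dataset.filter(lambda x, y: x["Cost"] > 0.0)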
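The arguments of the filter predicate mirror the structure of one dataset element. A minimal sketch of that idea, assuming an in-memory dataset built with tf.data.Dataset.from_tensor_slices in place of the CSV file; the sample values and the labels list are hypothetical and only there for illustration:

```python
import tensorflow as tf

# Hypothetical in-memory data standing in for the CSV rows.
features = {
    "Status": ["success", "failure", "success", "success"],
    "Cost": [0.0, 1.0, 2.0, 3.5],
}
labels = [0, 1, 1, 0]  # hypothetical label column

# Each element is a (features_dict, label) pair, so the predicate takes x and y.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

def keep(x, y):
    # Return a scalar boolean tensor: True keeps the element.
    return tf.logical_and(
        tf.equal(x["Status"], "success"),
        tf.greater(x["Cost"], 0.0),
    )

filtered = dataset.filter(keep)

# Collect the surviving rows as plain Python values.
kept = [
    (x["Status"].numpy().decode(), float(x["Cost"].numpy()), int(y.numpy()))
    for x, y in filtered
]
print(kept)
```

A named function with tf.logical_and combines both conditions in one pass; chaining two filter calls, as in the question, should give the same rows.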

CodePudding user response:

I hope this example will be useful.

CodePudding user response:

You can try something like this:

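import tensorflow as tf
import pandas as pd

# Write a small sample CSV to filter.
df = pd.DataFrame(data={'Status': ['Success', 'Failure', 'Failure', 'Success'], 'Cost': [0.0, 1.0, 1.0, 2.0]})
df.to_csv('data.csv', index=False)

# Read the same file that was just written. No label_name is given,
# so each element is a single dict of features and the predicate takes one argument.
dataset = tf.data.experimental.make_csv_dataset('data.csv', batch_size=2, num_epochs=1)
# Unbatch to single rows, then keep rows where both conditions hold.
dataset = dataset.unbatch().filter(lambda x: tf.logical_and(x["Status"] == 'Success', x["Cost"] > 0.0))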

for x in dataset:
  print(x['Status'], x['Cost'])

Output:

tf.Tensor(b'Success', shape=(), dtype=string) tf.Tensor(2.0, shape=(), dtype=float32)
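
Note that the string comparison is case-sensitive: the question matches 'success' while this answer matches 'Success'. If the casing in the file is not consistent, one option is to normalize it inside the predicate. A minimal sketch, assuming hypothetical in-memory sample data and a hypothetical helper named is_paid_success:

```python
import tensorflow as tf

# Hypothetical rows with mixed casing in the Status column.
dataset = tf.data.Dataset.from_tensor_slices({
    "Status": ["Success", "FAILURE", "success", "Success"],
    "Cost": [0.0, 1.0, 2.0, 3.0],
})

def is_paid_success(x):
    # Lower-case the status so 'Success' and 'success' both match.
    status = tf.strings.lower(x["Status"])
    return tf.logical_and(tf.equal(status, "success"), tf.greater(x["Cost"], 0.0))

# Collect the Cost values of the rows that pass the filter.
costs = [float(x["Cost"].numpy()) for x in dataset.filter(is_paid_success)]
print(costs)
```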