I'm converting a CSV file into a TFRecords file like this:
File: ./dataset/csv/file.csv
feature_1, feture_2, output
1, 1, 1
2, 2, 2
3, 3, 3
import tensorflow as tf
import csv
import os
print(tf.__version__)
def create_csv_iterator(csv_file_path, skip_header):
    """Lazily yield the rows of a CSV file as lists of strings.

    Args:
        csv_file_path: Path of the CSV file; opened with ``tf.io.gfile.GFile``.
        skip_header: When truthy, the first row is consumed and not yielded.

    Yields:
        One list of string fields per remaining row, as produced by
        ``csv.reader``.
    """
    with tf.io.gfile.GFile(csv_file_path) as csv_file:
        rows = csv.reader(csv_file)
        if skip_header:
            # Use a default so an empty file does not raise StopIteration
            # inside this generator (PEP 479 would turn it into RuntimeError).
            next(rows, None)
        yield from rows
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list Feature."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def create_example(row):
    """Build a tensorflow.Example Protocol Buffer object from one CSV row.

    The first three fields of ``row`` are converted with ``int`` and stored,
    in order, under the keys "feature_1", "feture_2" and "output".
    """
    column_names = ["feature_1", "feture_2", "output"]
    feature_map = {
        name: _int64_feature(int(row[position]))
        for position, name in enumerate(column_names)
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
def create_tfrecords_file(input_csv_file):
    """Create a TFRecords file for the given input CSV file.

    The output path is derived by replacing every occurrence of "csv" in
    ``input_csv_file`` with "tfrecords" (so "./dataset/csv/file.csv" becomes
    "./dataset/tfrecords/file.tfrecords"). The header row is skipped and
    empty rows are ignored; every other row is serialized as one Example.

    Args:
        input_csv_file: Path of the CSV file to convert.
    """
    output_tfrecord_file = input_csv_file.replace("csv", "tfrecords")

    # The writer does not create missing parent directories, so make sure the
    # target directory exists before opening the output file.
    output_dir = os.path.dirname(output_tfrecord_file)
    if output_dir:
        tf.io.gfile.makedirs(output_dir)

    # The context manager closes the writer even if a row fails to convert.
    with tf.io.TFRecordWriter(output_tfrecord_file) as writer:
        print("Creating TFRecords file at", output_tfrecord_file, "...")
        for row in create_csv_iterator(input_csv_file, skip_header=True):
            if not row:
                # csv.reader yields an empty list for blank lines.
                continue
            example = create_example(row)
            writer.write(example.SerializeToString())
    print("Finish Writing", output_tfrecord_file)


create_tfrecords_file("./dataset/csv/file.csv")
Then I read the generated TFRecords files using the ImportExampleGen class:
import os
import absl
import tensorflow_model_analysis as tfma
tf.get_logger().propagate = False
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip
# Interactive context used below to run each TFX component in turn.
context = InteractiveContext()
# Read the generated TFRecords files from ./dataset/tfrecords.
example_gen = tfx.components.ImportExampleGen(input_base="./dataset/tfrecords")
context.run(example_gen, enable_cache=True)
# Feed the examples produced by example_gen into StatisticsGen.
statistics_gen = tfx.components.StatisticsGen(
examples=example_gen.outputs['examples'])
context.run(statistics_gen, enable_cache=True)
# Build the schema from the statistics, without inferring feature shapes.
# NOTE: per the answer further down, infer_feature_shape=False is what makes
# preprocessing_fn receive SparseTensor inputs.
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=False)
context.run(schema_gen, enable_cache=True)
File: ./transform.py
def preprocessing_fn(inputs):
    """Identity preprocessing callback for tf.transform.

    Prints the incoming feature map so its contents can be inspected, then
    hands it back untouched.

    Args:
        inputs: Map from feature keys to raw, not-yet-transformed features.

    Returns:
        The same ``inputs`` mapping, unchanged.
    """
    print(inputs)
    outputs = inputs
    return outputs
# Run the Transform component on the ingested examples, using the generated
# schema and the module at ./transform.py (resolved to an absolute path).
transform = tfx.components.Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file=os.path.abspath("./transform.py"))
context.run(transform, enable_cache=True)
The print statement in the preprocessing_fn function shows that the inputs are SparseTensor objects. My question is: why? As far as I can tell, my dataset's samples are dense, so they should be Tensor objects instead. Am I doing something wrong?
CodePudding user response:
For anyone else who might be struggling with the same issue, I found the culprit: it's the SchemaGen class. This is how I was instantiating it:
# Original instantiation: SchemaGen is told not to infer feature shapes.
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=False)
I don't know what the use case is for asking the SchemaGen class not to infer the shape of the features, but the tutorial I was following had it set to False and I had simply copied and pasted it. Comparing with some other tutorials, I realized that this could be the reason why I was getting SparseTensor objects.
So, if you let SchemaGen infer the shape of your features, or you load a hand-crafted schema in which you've set the shapes yourself, you'll get a Tensor in your preprocessing_fn. But if the shapes are not set, the features will be instances of SparseTensor.
For the sake of completeness, this is the fixed snippet:
# Fixed instantiation: let SchemaGen infer the feature shapes so that
# preprocessing_fn receives Tensor inputs rather than SparseTensor.
schema_gen = tfx.components.SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=True)