One Hot Encoding in Tensorflow

Time:10-14

I've been following the TensorFlow walkthrough here to create my own categorical one-hot encoding (OHE) layer. The suggested layer is below, and I've followed the preceding steps in the guide very closely:

def get_category_encoding_layer(name, dataset, dtype, max_tokens=None):
  # Create a StringLookup layer which will turn strings into integer indices
  if dtype == 'string':
    index = preprocessing.StringLookup(max_tokens=max_tokens)
  else:
    index = preprocessing.IntegerLookup(max_tokens=max_tokens)

  # Prepare a Dataset that only yields our feature
  feature_ds = dataset.map(lambda x, y: x[name])

  # Learn the set of possible values and assign them a fixed integer index.
  index.adapt(feature_ds)

  # Create a CategoryEncoding layer for our integer indices.
  encoder = preprocessing.CategoryEncoding(num_tokens=index.vocabulary_size())

  # Apply one-hot encoding to our indices. The lambda function captures the
  # layer so we can use them, or include them in the functional model later.
  return lambda feature: encoder(index(feature))

However, the output doesn't match the guide. When my input to the layer is a list of n strings, instead of an output of shape (n, vocabulary size), I receive an output of shape (1, vocabulary size) with multiple categories incorrectly marked '1'. For example, with n=2 and vocabulary size=3, instead of getting an OHE of [[1, 0, 0], [0, 1, 0]], I get [1, 1, 0].

My code is exactly the same as the guide's, but it looks like the layer is "merging" the encodings of the elements of my input. Is there something wrong with the layer they provided, or could someone give me a pointer on what I could test?

CodePudding user response:

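By default, CategoryEncoding uses output_mode="multi_hot", which collapses all the indices in a sample into a single vector with a 1 at every category present. That's why you're getting a merged output of size (1, vocab_size). To get an OHE of size (n, vocab_size), set output_mode explicitly in your code: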

encoder = preprocessing.CategoryEncoding(num_tokens=index.vocabulary_size(), output_mode='one_hot')
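
To see why the two modes differ, here is a minimal plain-Python sketch of what each one computes. It is only an illustration, not the Keras implementation: the helper names one_hot and multi_hot are hypothetical, and the indices [0, 1] are chosen to match the n=2, vocabulary size=3 example in the question.

```python
# Plain-Python illustration of the two output modes; this is not Keras code.
# The function names and the sample indices are assumptions for illustration.

def one_hot(indices, num_tokens):
    """Return one row per index, with a single 1 in each row."""
    return [[1 if col == idx else 0 for col in range(num_tokens)]
            for idx in indices]


def multi_hot(indices, num_tokens):
    """Return a single row with a 1 at every index present in the sample."""
    present = set(indices)
    return [1 if col in present else 0 for col in range(num_tokens)]


indices = [0, 1]  # hypothetical integer indices for the n=2 inputs
print(one_hot(indices, 3))    # [[1, 0, 0], [0, 1, 0]]
print(multi_hot(indices, 3))  # [1, 1, 0]
```

The multi-hot result is the "merged" vector described in the question, while the one-hot result keeps one row per input element.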
