Lookup table in TensorFlow with string keys and lists of strings as values

Time:04-02

I would like to build a lookup table in TensorFlow where each key is a string and each value is a list of strings. However, none of the classes in tf.lookup currently seem to support this. Any ideas?

Answer:

I do not think tf.lookup has an implementation for exactly that use case, but you can combine tf.lookup.StaticHashTable and tf.gather to build your own lookup table. The hash table maps each key to a row index, and tf.gather uses that index to fetch the matching string list. You just need to make sure the keys and string lists are in the same order: key a corresponds to the first string list, key b to the second, and so on. Here is a working example:

import tensorflow as tf

class TensorLookup:
  def __init__(self, keys, strings):
    self.keys = keys
    self.strings = strings
    # Map each key to its row index in `strings`; unknown keys map to -1.
    self.table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(self.keys, tf.range(tf.shape(self.keys)[0])),
        default_value=-1)

  def lookup(self, key):
    index = self.table.lookup(key)
    # Return [''] for a missing key, otherwise the string list at that index.
    return tf.cond(
        tf.reduce_all(tf.equal(index, -1)),
        lambda: tf.constant(['']),
        lambda: tf.gather(self.strings, index))

keys = tf.constant(['a', 'b', 'c', 'd', 'e'])
strings = tf.ragged.constant([['fish', 'eating', 'cats'], 
                              ['cats', 'everywhere'], 
                              ['you', 'are', 'a', 'fine', 'lad'], 
                              ['a', 'mountain', 'over', 'there'],
                              ['bravo', 'at', 'charlie'] 
                              ])

tensor_dict = TensorLookup(keys = keys, strings = strings)

print(tensor_dict.lookup(tf.constant('a')))
print(tensor_dict.lookup(tf.constant('b')))
print(tensor_dict.lookup(tf.constant('c')))
print(tensor_dict.lookup(tf.constant('r'))) # the key r does not exist, so this returns ['']

Output:

tf.Tensor([b'fish' b'eating' b'cats'], shape=(3,), dtype=string)
tf.Tensor([b'cats' b'everywhere'], shape=(2,), dtype=string)
tf.Tensor([b'you' b'are' b'a' b'fine' b'lad'], shape=(5,), dtype=string)
tf.Tensor([b''], shape=(1,), dtype=string)

I intentionally use a ragged tensor to accommodate string lists of different lengths.
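
The lookup above handles one key at a time. If you need to look up several keys at once, one possible variation (a minimal sketch, not part of the example above) is to append an empty fallback row to the ragged tensor and use its index as the table's default value, so tf.gather can return a ragged result for the whole batch. The names rows, batch, result and dense are illustrative only.

```python
import tensorflow as tf

# Hypothetical data: one list of strings per key, in the same order as `keys`.
rows = [['fish', 'eating', 'cats'],
        ['cats', 'everywhere'],
        ['you', 'are', 'a', 'fine', 'lad']]
keys = tf.constant(['a', 'b', 'c'])

# Assumption: an extra empty row at the end acts as the fallback for unknown keys.
strings = tf.ragged.constant(rows + [[]])

# Keys map to row indices 0..len(rows)-1; unknown keys map to the empty row.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, tf.range(len(rows), dtype=tf.int64)),
    default_value=len(rows))

batch = tf.constant(['b', 'r', 'a'])   # 'r' is not in the table
indices = table.lookup(batch)          # one row index per key in the batch
result = tf.gather(strings, indices)   # ragged: one string list per key

print(result)

# Optional: pad to a dense tensor if downstream code needs a fixed shape.
dense = result.to_tensor(default_value='')
print(dense.shape)
```

With this layout a missing key yields an empty list rather than [''], which may or may not be what you want; adjust the fallback row to suit.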
