Sort Tensorflow HashTable by value

Time:02-24

My code:

import tensorflow as tf

h_table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=[0, 1, 2, 3, 4, 5],
          values=[12.3, 11.1, 51.5, 34.3, 87.3, 57.8]
      ),
      default_value=tf.constant(-1.),  # float, to match the float32 values
      name="h_table"
    )

I want to sort h_table by value, in descending order, so that the new hash table is:

keys = [4, 5, 2, 3, 0, 1]
values = [87.3, 57.8, 51.5, 34.3, 12.3, 11.1]

The equivalent in plain Python is:

h_table = {0: 12.3, 1: 11.1, 2: 51.5, 3: 34.3, 4: 87.3, 5: 57.8}
h_sorted = dict(sorted(h_table.items(), key=lambda x: x[1], reverse=True))
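When the keys and values are held as two parallel lists, as they are in the KeyValueTensorInitializer above, the same sort becomes two steps: compute the order of the values from largest to smallest (an "argsort"), then reorder both lists with that order (a "gather"). A plain-Python sketch of that idea, using the question's data; the names order, sorted_keys and sorted_values are only illustrative:

```python
# Parallel key/value lists, as passed to KeyValueTensorInitializer.
keys = [0, 1, 2, 3, 4, 5]
values = [12.3, 11.1, 51.5, 34.3, 87.3, 57.8]

# "argsort", descending: positions of the values from largest to smallest.
order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)

# "gather": reorder both lists with the same positions.
sorted_keys = [keys[i] for i in order]
sorted_values = [values[i] for i in order]

print(order)          # [4, 5, 2, 3, 0, 1]
print(sorted_keys)    # [4, 5, 2, 3, 0, 1]
print(sorted_values)  # [87.3, 57.8, 51.5, 34.3, 12.3, 11.1]
```

In TensorFlow these two steps correspond to tf.argsort and tf.gather on the key and value tensors.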

How can I implement this kind of dictionary operation in TensorFlow, using tensors?

Answer:

A tf.lookup.StaticHashTable is immutable once initialized, so you have to build a new one from the sorted keys and values:

import tensorflow as tf

h_table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=[0, 1, 2, 3, 4, 5],
          values=[12.3, 11.1, 51.5, 34.3, 87.3, 57.8]
      ),
      default_value=tf.constant(-1.),
      name="h_table"
    )

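# export() returns every key and value stored in the table. Their order is
# unspecified, which does not matter because they are sorted next.
keys, values = h_table.export()

# Positions of the values from largest to smallest.
value_indices = tf.argsort(values, direction='DESCENDING')
keys = tf.gather(keys, value_indices)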

new_h_table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=keys,
          values=h_table.lookup(keys)
      ),
      default_value=tf.constant(-1.),
      name="new_h_table"
    )

print(keys)
print(new_h_table.lookup(keys))

Output:

tf.Tensor([4 5 2 3 0 1], shape=(6,), dtype=int32)
tf.Tensor([87.3 57.8 51.5 34.3 12.3 11.1], shape=(6,), dtype=float32)
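As an alternative to tf.argsort followed by a second lookup, tf.math.top_k with k set to the number of elements returns all values already sorted in descending order, together with their original positions. The sketch below assumes the keys and values are available as plain tensors (as in the question) rather than extracted from an existing table; the table name "sorted_h_table" is illustrative. Keep in mind that a hash table does not promise any ordering of its entries, so the sorted result is really the pair of tensors sorted_keys and sorted_values; the table itself only provides lookup by key.

```python
import tensorflow as tf

# Same keys/values as in the question (int32 keys, float32 values).
keys = tf.constant([0, 1, 2, 3, 4, 5])
values = tf.constant([12.3, 11.1, 51.5, 34.3, 87.3, 57.8])

# top_k with k equal to the number of elements returns every value,
# sorted from largest to smallest, plus the indices they came from.
sorted_values, order = tf.math.top_k(values, k=tf.size(values))

# Reorder the keys with the same indices so keys and values stay paired.
sorted_keys = tf.gather(keys, order)

# Build a new immutable table from the sorted pairs.
sorted_table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(sorted_keys, sorted_values),
    default_value=tf.constant(-1.0),
    name="sorted_h_table",
)

print(sorted_keys)    # tf.Tensor([4 5 2 3 0 1], shape=(6,), dtype=int32)
print(sorted_values)
# Key 9 is not in the table, so it falls back to default_value (-1.0).
print(sorted_table.lookup(tf.constant([4, 0, 9])))
```

tf.math.top_k sorts its output by default (sorted=True), so no separate argsort is needed; for tied values it keeps the lower original index first.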