How to get a unique id for every occurrence of an item inside a tensor of segment_ids in Tensorflow

Time:03-23

Suppose x contains segment ids. I want to give every item a unique id within its segment id, i.e. number the occurrences of each segment id separately. This has to be done with TensorFlow operations.

x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])

Expected output:

[0, 1, 0, 1, 0, 1, 0, 2]

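In other words, each element gets the number of earlier elements with the same segment id (0 for its first occurrence, 1 for its second, and so on). I don't want to use a py_func.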
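To pin down the expected semantics, here is a plain-Python running count per id. It is only a reference for checking results, not a TensorFlow solution; the function name `occurrence_index_reference` is hypothetical.

```python
from collections import defaultdict


def occurrence_index_reference(segment_ids):
    """For each element, return how many times its id appeared earlier in the list."""
    seen = defaultdict(int)  # id -> number of occurrences seen so far
    out = []
    for s in segment_ids:
        out.append(seen[s])  # 0 for the first occurrence, 1 for the second, ...
        seen[s] += 1
    return out


print(occurrence_index_reference([1, 1, 2, 2, 3, 3, 4, 1]))  # [0, 1, 0, 1, 0, 1, 0, 2]
```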

CodePudding user response:

Try tf.unique_with_counts with tf.while_loop:

import tensorflow as tf

x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])

unique, _, count = tf.unique_with_counts(x)

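i = tf.constant(0)
result = tf.zeros_like(x)

# Loop condition: continue while there are unique values left to process.
def cond(i, result, unique, count):
    return tf.less(i, tf.shape(unique)[0])

# Loop body: for the i-th unique value, write 0, 1, ..., count[i]-1
# into the positions of x that hold that value.
def body(i, result, unique, count):
    positions = tf.where(tf.equal(unique[i], x))  # [count[i], 1] indices
    result = tf.tensor_scatter_nd_update(result, positions, tf.range(count[i]))
    return i + 1, result, unique, count

_, result, _, _ = tf.while_loop(cond, body, loop_vars=[i, result, unique, count])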

print(result)
# tf.Tensor([0 1 0 1 0 1 0 2], shape=(8,), dtype=int32)
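The while_loop above still iterates once per unique value. A loop-free alternative, sketched here under the assumption that x is a 1-D integer tensor, is to stable-sort the ids, measure each element's offset inside its run of equal ids, and then undo the sort. The function name `occurrence_index` is hypothetical; the ops used are `tf.argsort(stable=True)`, `tf.unique`, `tf.math.segment_min` and `tf.math.invert_permutation`.

```python
import tensorflow as tf


@tf.function
def occurrence_index(x):
    """For each element of the 1-D int tensor x, count earlier elements equal to it."""
    # Stable sort keeps equal ids in their original relative order.
    order = tf.argsort(x, stable=True)
    sorted_x = tf.gather(x, order)
    # Equal ids now form contiguous runs; run_id numbers the runs 0, 1, 2, ...
    _, run_id = tf.unique(sorted_x)
    positions = tf.range(tf.size(x))
    # Start position of each run (run_id is non-decreasing, as segment_min requires).
    run_start = tf.math.segment_min(positions, run_id)
    # Offset inside the run = occurrence index, expressed in sorted order.
    rank_sorted = positions - tf.gather(run_start, run_id)
    # Map back to the original order: inv[order[k]] = k.
    return tf.gather(rank_sorted, tf.math.invert_permutation(order))


x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])
print(occurrence_index(x).numpy())  # [0 1 0 1 0 1 0 2]
```

Everything is expressed as graph ops, so it works inside `tf.function` without a Python-level loop; the cost is dominated by the sort.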

CodePudding user response:

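A plain Python loop over the segment ids also works in eager mode, as long as each iteration writes into the running result instead of creating a fresh tensor:

import tensorflow as tf
x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])
n_pillars = int(tf.reduce_max(x)) + 1  # assumption: segment ids lie in [0, n_pillars)
point_ids = tf.zeros_like(x)
for i in range(n_pillars):
    indices = tf.cast(tf.where(tf.equal(x, i)), tf.int32)  # [k, 1] positions of id i
    updates = tf.range(tf.shape(indices)[0])  # 0, 1, ..., k-1
    # update point_ids in place so ids handled in earlier iterations are kept
    point_ids = tf.tensor_scatter_nd_update(point_ids, indices, updates)
print(point_ids)  # tf.Tensor([0 1 0 1 0 1 0 2], shape=(8,), dtype=int32)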
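When the number of distinct segment ids is small and known in advance (the role `n_pillars` plays in the loop above), the loop can be replaced by a one-hot running count. This is a sketch assuming every id lies in [0, num_ids); `num_ids` and `occurrence_index_one_hot` are hypothetical names. It uses O(len(x) * num_ids) memory, and an id outside the range would come out as -1.

```python
import tensorflow as tf


def occurrence_index_one_hot(x, num_ids):
    """Occurrence index of each element of x; assumes every id lies in [0, num_ids)."""
    # Row k has a single 1 in column x[k]. Shape: [len(x), num_ids].
    one_hot = tf.one_hot(x, num_ids, dtype=tf.int32)
    # Running count of every id over rows 0..k (inclusive).
    running = tf.cumsum(one_hot, axis=0)
    # Keep only column x[k] of row k, then subtract 1 so the first occurrence is 0.
    return tf.reduce_sum(running * one_hot, axis=1) - 1


x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])
print(occurrence_index_one_hot(x, num_ids=5).numpy())  # [0 1 0 1 0 1 0 2]
```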