Suppose x contains segment ids. I want to give a unique id to every item within each segment. This needs to be done with TensorFlow operations:
x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])
Needed output:
[0, 1, 0, 1, 0, 1, 0, 2]
It just numbers the items within each segment id, in order of appearance. I don't want to use a py_func.
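To pin down the expected behaviour, here is a plain-Python reference that can be used to check any TensorFlow solution. It is only a sketch of the semantics, not a TensorFlow op, and the name within_segment_ids is made up for illustration:

```python
from collections import defaultdict

def within_segment_ids(segment_ids):
    """For each item, count how many earlier items share its segment id."""
    seen = defaultdict(int)  # segment id -> number of items seen so far
    out = []
    for s in segment_ids:
        out.append(seen[s])
        seen[s] += 1
    return out

expected = within_segment_ids([1, 1, 2, 2, 3, 3, 4, 1])
print(expected)  # [0, 1, 0, 1, 0, 1, 0, 2]
```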
CodePudding user response:
Try tf.unique_with_counts with tf.while_loop:
import tensorflow as tf
x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])
unique, _, count = tf.unique_with_counts(x)
i = tf.constant(0)
result = tf.zeros_like(x)
c = lambda i, result, unique, count: tf.less(i, tf.shape(unique)[0])
b = lambda i, r, u, c: (tf.add(i, 1), tf.tensor_scatter_nd_update(r, tf.where(tf.equal(u[i], x)), tf.range(c[i])), u, c)
_, result, _, _ = tf.while_loop(c, b, loop_vars=[i, result, unique, count])
print(result)
# tf.Tensor([0 1 0 1 0 1 0 2], shape=(8,), dtype=int32)
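If the number of distinct segment ids is small, a loop-free variant may also work. The following is only a sketch of a different technique (one-hot encoding plus a cumulative sum), not the approach above; it builds an [n, k] matrix, so memory grows with the number of items times the number of distinct ids. The name vectorized is just for illustration:

```python
import tensorflow as tf

x = tf.constant([1, 1, 2, 2, 3, 3, 4, 1])

# Map each segment id to a dense index 0..k-1 in order of first appearance.
unique, idx = tf.unique(x)
# one_hot[n, j] is 1 when item n belongs to the j-th distinct segment.
one_hot = tf.one_hot(idx, depth=tf.size(unique), dtype=tf.int32)
# Running count per segment down the rows, kept only in each item's own column.
running = tf.cumsum(one_hot, axis=0) * one_hot
# The inclusive count minus one is the item's position inside its segment.
vectorized = tf.reduce_sum(running, axis=1) - 1
print(vectorized.numpy().tolist())  # [0, 1, 0, 1, 0, 1, 0, 2]
```

Since it uses only standard ops, this should also run inside a tf.function, though I have only checked it eagerly on the example input.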
CodePudding user response:
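# Assumption: n_pillars is one more than the largest segment id in x (eager mode)
n_pillars = int(tf.reduce_max(x)) + 1
point_ids = tf.zeros_like(x)
for i in range(n_pillars):
    # Positions of the items that belong to segment i, shape [k, 1]
    indices = tf.cast(tf.where(tf.equal(x, i)), tf.int32)
    updates = tf.range(tf.shape(indices)[0])
    # Accumulate; a plain assignment would keep only the last segment's ids
    point_ids += tf.scatter_nd(indices, updates, shape=tf.shape(x))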