For example, given the input:
a = tf.constant([1, 2, 3, 1, 2, 1, 4])
I want:
a_mask = [True, True, True, False, False, False, True]
CodePudding user response:
I don't know if this answers your question:
l = list(a.numpy())
# Init a set and the mask
se = set()
a_mask = []
for x in l:
    if x in se:
        # Already seen: not the first occurrence
        a_mask.append(False)
    else:
        # First occurrence: mark it and remember the value
        a_mask.append(True)
        se.add(x)
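The same loop can be wrapped in a small reusable function. This is only a sketch: the name first_occurrence_mask is made up for illustration, and it works on a plain Python list. With a tensor you would presumably pass something like a.numpy().tolist(), which needs eager execution.

```python
def first_occurrence_mask(values):
    """Return a list of bools: True where a value appears for the first time."""
    seen = set()
    mask = []
    for x in values:
        # True only if x has not been encountered earlier in the sequence
        mask.append(x not in seen)
        seen.add(x)
    return mask


print(first_occurrence_mask([1, 2, 3, 1, 2, 1, 4]))
# [True, True, True, False, False, False, True]
```

If you need a tensor again, the resulting list can be passed to tf.constant with dtype=tf.bool.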
CodePudding user response:
Try using tf.math.unsorted_segment_min with tf.tensor_scatter_nd_update:
import tensorflow as tf
a = tf.constant([1, 2, 3, 1, 2, 1, 4])
v, i = tf.unique(a)
indices = tf.math.unsorted_segment_min(tf.range(tf.shape(a)[0]), i, tf.shape(v)[0])
updates = tf.ones_like(indices, dtype=tf.bool)
mask = tf.tensor_scatter_nd_update(tf.zeros_like(a, dtype=tf.bool), tf.expand_dims(indices, axis=-1), updates)
print(mask)
tf.Tensor([ True True True False False False True], shape=(7,), dtype=bool)
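To see why this works, here is a rough pure-Python walk-through of the same three steps. It only mimics what the TensorFlow ops compute and is not how they are implemented; the function name first_occurrence_mask_steps and the intermediate variable names are hypothetical.

```python
def first_occurrence_mask_steps(a):
    # Step 1 (mirrors tf.unique): unique values in order of first appearance,
    # plus, for each element of `a`, the index of its value in that list.
    values, idx, seen = [], [], {}
    for x in a:
        if x not in seen:
            seen[x] = len(values)
            values.append(x)
        idx.append(seen[x])

    # Step 2 (mirrors tf.math.unsorted_segment_min): for each unique value,
    # the smallest position at which it occurs in `a`.
    first_pos = [len(a)] * len(values)
    for pos, seg in enumerate(idx):
        first_pos[seg] = min(first_pos[seg], pos)

    # Step 3 (mirrors tf.tensor_scatter_nd_update): start from all False
    # and set True at each first position.
    mask = [False] * len(a)
    for pos in first_pos:
        mask[pos] = True
    return values, idx, first_pos, mask


values, idx, first_pos, mask = first_occurrence_mask_steps([1, 2, 3, 1, 2, 1, 4])
print(values)     # [1, 2, 3, 4]
print(idx)        # [0, 1, 2, 0, 1, 0, 3]
print(first_pos)  # [0, 1, 2, 6]
print(mask)       # [True, True, True, False, False, False, True]
```

Here first_pos plays the role of indices in the TensorFlow code above: the positions 0, 1, 2 and 6 are where each distinct value shows up first, so only those entries of the mask become True.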