How to catch the first matching element in TensorFlow

Time:03-08

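I want a boolean mask that is True at the first occurrence of each value and False at every later repeat. For example, given the input: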

a = tf.constant([1, 2, 3, 1, 2, 1, 4])

I want:

a_mask = [True, True, True, False, False, False, True]
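To make the expected result concrete, here is a small NumPy reference sketch (not a TensorFlow solution, just a way to cross-check answers). It relies on np.unique(..., return_index=True), which returns the index of the first occurrence of each unique value:

```python
import numpy as np

a = np.array([1, 2, 3, 1, 2, 1, 4])

# return_index=True gives the position of the first occurrence of each unique value
_, first_idx = np.unique(a, return_index=True)

# Start with an all-False mask, then mark only the first occurrences
a_mask = np.zeros(a.shape, dtype=bool)
a_mask[first_idx] = True

print(a_mask.tolist())  # [True, True, True, False, False, False, True]
```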

Answer:

I'm not sure if this answers your question, but a plain Python loop works (it needs eager execution, since it calls .numpy() on the tensor):

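import tensorflow as tf

a = tf.constant([1, 2, 3, 1, 2, 1, 4])
l = a.numpy().tolist()

# Track the values already seen; only the first occurrence of each gets True
seen = set()
a_mask = []

for x in l:
    if x in seen:
        a_mask.append(False)
    else:
        a_mask.append(True)
        seen.add(x)

print(a_mask)  # [True, True, True, False, False, False, True]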
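Because the loop above calls .numpy(), it will not run on symbolic tensors inside a tf.function. One possible workaround, sketched here, is to wrap the same loop in tf.py_function; the function names first_occurrence_mask_py and first_occurrence_mask are hypothetical. The Python body still runs eagerly outside the graph, so it is not vectorized and is not serialized with the graph.

```python
import tensorflow as tf

def first_occurrence_mask_py(x):
    # Inside tf.py_function, x arrives as an eager tensor, so .numpy() works
    seen = set()
    mask = []
    for v in x.numpy().tolist():
        mask.append(v not in seen)
        seen.add(v)
    return tf.constant(mask, dtype=tf.bool)

@tf.function
def first_occurrence_mask(a):
    mask = tf.py_function(first_occurrence_mask_py, [a], Tout=tf.bool)
    # py_function drops static shape information, so restore it from the input
    mask.set_shape(a.shape)
    return mask

a = tf.constant([1, 2, 3, 1, 2, 1, 4])
print(first_occurrence_mask(a))
```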

Answer:

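Try using tf.math.unsorted_segment_min with tf.tensor_scatter_nd_update. tf.unique maps every element to the index of its unique value, so taking the minimum position within each of those groups gives the position of each value's first occurrence. Scattering True into an all-False mask at those positions produces the result: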

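import tensorflow as tf

a = tf.constant([1, 2, 3, 1, 2, 1, 4])
# v: unique values; i: for each element of a, the index of its value in v
v, i = tf.unique(a)
# Smallest position in each group = position of that value's first occurrence
indices = tf.math.unsorted_segment_min(tf.range(tf.shape(a)[0]), i, tf.shape(v)[0])
updates = tf.ones_like(indices, dtype=tf.bool)
# Start from an all-False mask and set True at the first-occurrence positions
mask = tf.tensor_scatter_nd_update(tf.zeros_like(a, dtype=tf.bool), tf.expand_dims(indices, axis=-1), updates)
print(mask)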
Output:

tf.Tensor([ True  True  True False False False  True], shape=(7,), dtype=bool)
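For short vectors, a different fully vectorized approach (swapped in here, not part of the answer above) compares every pair of elements: a[j] is a first occurrence when no earlier position k < j holds the same value. This sketch uses a hypothetical helper name, first_occurrence_mask_pairwise, and builds an n x n boolean matrix, so its memory grows quadratically with the input length; the tf.unique approach above avoids that.

```python
import tensorflow as tf

def first_occurrence_mask_pairwise(a):
    # equal[j, k] is True when a[j] == a[k]
    equal = tf.equal(tf.expand_dims(a, 1), tf.expand_dims(a, 0))
    # earlier[j, k] is True when k < j (strictly lower triangle)
    idx = tf.range(tf.shape(a)[0])
    earlier = tf.expand_dims(idx, 1) > tf.expand_dims(idx, 0)
    # a[j] is a first occurrence if no earlier position has the same value
    return tf.logical_not(tf.reduce_any(tf.logical_and(equal, earlier), axis=1))

a = tf.constant([1, 2, 3, 1, 2, 1, 4])
print(first_occurrence_mask_pairwise(a))
```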