Extract K left-most columns from a 2D tensor with mask

Time:02-15

Suppose we have two tensors like these:

A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
B = [[True, True, True, True],
     [True, False, True, True]]

For each row of A, I want to extract the K left-most elements whose corresponding mask value in B is True. In the above example, if K=2, the result should be

C = [[1, 2],
     [5, 7]]

6 is not included in C because its corresponding mask value is False.

I was able to do that with the following code:

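import tensorflow as tf

A = tf.constant(A)  # convert the nested lists above to tensors
B = tf.constant(B)
K = 2
batch_size = 2
C = tf.zeros((batch_size, K), tf.int32)

for batch_idx in tf.range(batch_size):
    a = A[batch_idx]
    b = B[batch_idx]

    # Keep only the entries of this row whose mask is True, then take the first K.
    tmp = tf.boolean_mask(a, b)
    tmp = tmp[:K]

    # Write the K selected values into row batch_idx of C.
    C = tf.tensor_scatter_nd_update(
        C, [[batch_idx]], tf.expand_dims(tmp, axis=0))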

But I don't want to iterate over A and B with a for loop. Is there any way to do this with tensor operations only?

CodePudding user response:

I'm not sure it will work for all corner cases, but you could try tf.ragged.boolean_mask:

import tensorflow as tf

A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
B = [[True, True, True, True],
     [True, False, True, True]]
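K = 2
tmp = tf.ragged.boolean_mask(A, B)  # ragged: each row keeps only its True entries
C = tmp[:, :K].to_tensor()          # first K entries of each row, as a dense tensor
print(C)

Output for K = 2:

tf.Tensor(
[[1 2]
 [5 7]], shape=(2, 2), dtype=int32)

Output for K = 3: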

tf.Tensor(
[[1 2 3]
 [5 7 8]], shape=(2, 3), dtype=int32)
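
One corner case worth checking is a row with fewer than K True entries. to_tensor() pads short rows (with 0 by default), and if every row is short the result can end up narrower than K. Below is a minimal sketch of one way to pin the output width, assuming padding is acceptable for your use case. The helper name first_k_masked and the pad_value parameter are made up for illustration and are not part of TensorFlow.

```python
import tensorflow as tf


def first_k_masked(values, mask, k, pad_value=0):
    """Keep the first k entries of each row whose mask is True.

    Hypothetical helper: rows with fewer than k True entries are
    right-padded with pad_value so the result always has k columns.
    """
    # Each row of the ragged tensor holds only the entries where mask is True.
    ragged = tf.ragged.boolean_mask(values, mask)
    # Slice the first k entries per row, then densify to a fixed width of k.
    return ragged[:, :k].to_tensor(default_value=pad_value, shape=[None, k])


A = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
# The second row has only one True entry, so it needs padding for k = 2.
B = tf.constant([[True, True, True, True],
                 [True, False, False, False]])

C = first_k_masked(A, B, k=2, pad_value=-1)
print(C.numpy().tolist())  # [[1, 2], [5, -1]]
```

Using a pad value that cannot occur in A (here -1) makes padded positions easy to tell apart from real data.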