Suppose we have two tensors:
A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
B = [[True, True, True, True],
     [True, False, True, True]]
For each row of A, I want to extract the K left-most values whose corresponding entry in the boolean mask B is True. In the example above, with K = 2, the result should be
C = [[1, 2],
     [5, 7]]
6 is not included in C because its corresponding mask entry is False.
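To pin down the intended behaviour, here is a plain-Python reference, written only as a sketch of the specification. The function name first_k_masked is made up for illustration, and the code uses lists rather than tensors.

```python
def first_k_masked(a, b, k):
    """Per row, keep the first k values of a whose mask entry in b is True."""
    result = []
    for row, mask_row in zip(a, b):
        # Drop every value whose mask entry is False, preserving order.
        kept = [value for value, keep in zip(row, mask_row) if keep]
        # Keep only the k left-most survivors.
        result.append(kept[:k])
    return result


A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
B = [[True, True, True, True],
     [True, False, True, True]]

print(first_k_masked(A, B, 2))  # [[1, 2], [5, 7]]
```

Note that a row with fewer than K True entries simply comes out shorter here; a tensor version has to decide how to pad such rows.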
I was able to do that with the following code:
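batch_size = 2
K = 2
C = tf.zeros((batch_size, K), tf.int32)
for batch_idx in tf.range(batch_size):
    a = A[batch_idx]
    b = B[batch_idx]
    # Keep only the values of this row whose mask entry is True, then take the first K
    tmp = tf.boolean_mask(a, b)
    tmp = tmp[:K]
    # Write the K selected values into row batch_idx of C
    C = tf.tensor_scatter_nd_update(
        C, [[batch_idx]], tf.expand_dims(tmp, axis=0))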
But I don't want to iterate over A and B with a for loop. Is there a way to do this with matrix operations only?
Answer:
I'm not sure it will work for all corner cases, but you could try tf.ragged.boolean_mask:
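import tensorflow as tf
A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
B = [[True, True, True, True],
     [True, False, True, True]]
K = 2
# Each row keeps only the values whose mask entry is True (rows may differ in length)
tmp = tf.ragged.boolean_mask(A, B)
# Take the first K values of every row and convert back to a dense tensor
C = tmp[:, :K].to_tensor()
print(C)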
With K = 2, C is:
tf.Tensor(
[[1 2]
 [5 7]], shape=(2, 2), dtype=int32)
With K = 3:
tf.Tensor(
[[1 2 3]
 [5 7 8]], shape=(2, 3), dtype=int32)
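One corner case worth checking is a row with fewer than K True entries. By default to_tensor() pads short rows with 0 and only up to the longest row, so the result can end up with fewer than K columns. The sketch below shows two possible ways to always get exactly K columns: passing shape and default_value to to_tensor(), and a dense variant built on a stable tf.argsort followed by tf.gather. The function names and the -1 padding value are assumptions chosen for illustration, not part of the original answer.

```python
import tensorflow as tf


def first_k_ragged(a, mask, k, pad=-1):
    """Ragged variant: mask, slice, then pad every row to exactly k columns."""
    masked = tf.ragged.boolean_mask(a, mask)
    # shape=[None, k] fixes the number of columns; short rows get `pad`.
    return masked[:, :k].to_tensor(default_value=pad, shape=[None, k])


def first_k_dense(a, mask, k, pad=-1):
    """Dense variant: move the True columns to the front with a stable sort."""
    a = tf.convert_to_tensor(a)
    mask = tf.convert_to_tensor(mask)
    # Sort key is 0 where the mask is True and 1 where it is False, so a
    # stable ascending argsort lists the True columns first, in original order.
    key = tf.cast(tf.logical_not(mask), tf.int32)
    order = tf.argsort(key, axis=1, stable=True)[:, :k]
    picked = tf.gather(a, order, batch_dims=1)
    # Positions that came from False columns are replaced by the pad value.
    valid = tf.gather(mask, order, batch_dims=1)
    return tf.where(valid, picked, tf.cast(pad, a.dtype))


A = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
B = tf.constant([[True, False, False, False],
                 [True, False, True, True]])

print(first_k_ragged(A, B, 3).numpy())
print(first_k_dense(A, B, 3).numpy())
```

For this input both functions should return [[1, -1, -1], [5, 7, 8]]. The dense variant avoids ragged tensors entirely, which may be easier to use inside a compiled graph, but it assumes K is not larger than the number of columns in A.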