Drop a row in a tensor if the sum of the elements is lower than some threshold

Time:12-21

How can I drop the rows of a tensor whose elements sum to less than a threshold of -1? For example:

import tensorflow as tf

tensor = tf.random.normal((3, 3))
print(tensor)
tf.Tensor(
[[ 0.506158    0.53865975 -0.40939444]
 [ 0.4917719  -0.1575156   1.2308844 ]
 [ 0.08580616 -1.1503975  -2.252681  ]], shape=(3, 3), dtype=float32)

Since the sum of the last row (about -3.32) is smaller than -1, I need to remove it and get a tensor of shape (2, 3):

tf.Tensor(
[[ 0.506158    0.53865975 -0.40939444]
 [ 0.4917719  -0.1575156   1.2308844 ]], shape=(2, 3), dtype=float32)

I know how to use tf.reduce_sum, but I do not know how to delete rows from a tensor. Something like pandas' df.drop would be nice.

Answer:

tf.boolean_mask is all you need: compute a boolean mask from the row sums, then keep only the rows where the mask is True.

import tensorflow as tf

tensor = tf.constant([
    [ 0.506158,    0.53865975, -0.40939444],
    [ 0.4917719,  -0.1575156,   1.2308844 ],
    [ 0.08580616, -1.1503975,  -2.252681  ],
])

# True for every row whose elements sum to more than -1
mask = tf.reduce_sum(tensor, axis=1) > -1
# <tf.Tensor: shape=(3,), dtype=bool, numpy=array([ True,  True, False])>

# Keep only the rows (axis 0) where mask is True
result = tf.boolean_mask(
    tensor=tensor,
    mask=mask,
    axis=0
)
# <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
# array([[ 0.506158  ,  0.53865975, -0.40939444],
#        [ 0.4917719 , -0.1575156 ,  1.2308844 ]], dtype=float32)>
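A minimal sketch wrapping the boolean-mask approach in a reusable function. The name `drop_rows_below` and its `threshold` parameter are hypothetical, not part of TensorFlow. It keeps rows whose sum is greater than or equal to the threshold, so a row summing to exactly -1 is kept, matching the question's "lower than -1" wording (the `> -1` comparison above would drop such a row).

```python
import tensorflow as tf


def drop_rows_below(x: tf.Tensor, threshold: float) -> tf.Tensor:
    """Hypothetical helper: drop the rows of a 2-D tensor whose sum is below `threshold`."""
    # One sum per row: shape (num_rows,)
    row_sums = tf.reduce_sum(x, axis=1)
    # Keep a row unless its sum is strictly lower than the threshold
    keep = row_sums >= threshold
    # Select the kept rows along the first axis
    return tf.boolean_mask(x, keep, axis=0)


x = tf.constant([
    [0.506158, 0.53865975, -0.40939444],
    [0.4917719, -0.1575156, 1.2308844],
    [0.08580616, -1.1503975, -2.252681],
])
result = drop_rows_below(x, -1.0)
print(result.shape.as_list())  # [2, 3]
```

If every row falls below the threshold, the result is an empty tensor of shape (0, 3) rather than an error.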

Answer:

You could use tf.where to get the indices of the rows whose sum is greater than -1, and then use tf.gather to keep only those rows, which drops the others.

import tensorflow as tf

tf.random.set_seed(0)

x = tf.random.normal((3, 3))
# <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
# array([[ 1.5110626 ,  0.42292204, -0.41969493],
#        [-1.0360372 , -1.2368279 ,  0.47027302],
#        [-0.01397489,  1.1888583 ,  0.60253334]], dtype=float32)>

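# tf.where returns indices with shape (n, 1); squeeze only axis 1 so the
# index vector stays 1-D (and the result stays 2-D) even when n == 1
keep = tf.squeeze(tf.where(tf.reduce_sum(x, axis=1) > -1), axis=1)
x = tf.gather(x, indices=keep, axis=0)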
# <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
# array([[ 1.5110626 ,  0.42292204, -0.41969493],
#        [-0.01397489,  1.1888583 ,  0.60253334]], dtype=float32)>
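For comparison, a sketch of the same tf.where + tf.gather idea as a helper. The name `keep_rows_above` is hypothetical. Called with a single boolean argument, tf.where returns the indices of the True entries as an int64 tensor of shape (num_true, 1); taking column 0 flattens it to a 1-D index vector regardless of how many rows survive, so the result stays 2-D even when only one row is kept.

```python
import tensorflow as tf


def keep_rows_above(x: tf.Tensor, threshold: float) -> tf.Tensor:
    """Hypothetical helper: keep only the rows of a 2-D tensor whose sum exceeds `threshold`."""
    # Boolean vector, one entry per row
    keep = tf.reduce_sum(x, axis=1) > threshold
    # tf.where(keep) has shape (num_kept, 1); column 0 gives a 1-D index vector
    indices = tf.where(keep)[:, 0]
    # Gather the selected rows along the first axis
    return tf.gather(x, indices, axis=0)


x = tf.constant([
    [1.0, 1.0, 1.0],
    [-2.0, -2.0, -2.0],
    [-1.0, -1.0, -1.0],
])
result = keep_rows_above(x, -1.0)  # only the first row has a sum above -1
print(result.shape.as_list())  # [1, 3]
```

tf.boolean_mask does the index computation internally, so the gather version is mainly useful when you also want the surviving row indices for something else.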