Update a tensor based on the highest value of another tensor when passed through a mask

Time:08-31

I'm learning some transformations in TensorFlow and want to know what my options are for achieving the following.

tensor_1 = [0, 1, 2, 3]
tensor_2 = [0, 0, 0, 0]
mask = [True, False, True, False]

Expected outcome: tensor_2 = [0, 0, 1, 0]. Essentially, I want to pass tensor_1 through the mask and, for values that are True, update the same index in tensor_2 to 1. For example, in the example above, the highest value that passes through the mask is 2, so we update the third element of tensor_2.

Also, I need to do this for a batch of images (tensor_1 in our example) of shape (batch_size, 128, 128, 3), where each image has 3 channels. We need to find the maximum in the flattened image (128, 128, 3) and apply the transformation to all 3 channels of that pixel in tensor_2, so that the pixel has 1 in all 3 channels. The final shape is (batch_size, 128, 128, 3). The mask also has shape (batch_size, 128, 128, 3).

I understand that this is very specific, but I want to understand these transformations and am not sure where to begin other than trying out some scenarios.

CodePudding user response:

I'm making some assumptions about your question because it contains some contradictory information. I will update this answer if you feel I missed something.

"I want to pass tensor_1 through the mask and for values that are True, I want to update the same index in tensor_2 as 1"

If you just want to use a preexisting mask to update values, you can use tf.where:

import tensorflow as tf

tensor_1 = [0, 1, 2, 3]
tensor_2 = [0, 0, 0, 0]
mask = [True, False, True, False]

# Write 1 wherever mask is True; keep tensor_2's value elsewhere.
tf.where(mask, 1, tensor_2)
>>>
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 0, 1, 0], dtype=int32)>

"For example, in the example above, the highest value when passed through the mask is 2, so we update the third index in tensor_2"

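The first description (set every True position to 1) and this example (set only the position of the highest value) don't match. Note also that the highest value in tensor_1 overall is 3, not 2; 2 is the highest only among the positions where the mask is True. If you want the overall maximum, you can use tf.where and tf.math.reduce_max directly, without creating a separate mask: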

# The condition is True only where tensor_1 equals its overall maximum.
tf.where(tf.equal(tensor_1, tf.math.reduce_max(tensor_1)), 1, tensor_2)
>>>
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 0, 0, 1], dtype=int32)>
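
If the goal is instead the highest value among only the positions where the mask is True, which is what the expected output [0, 0, 1, 0] in the question suggests, one possible approach is to push the masked-out values down to the dtype's minimum before taking the maximum. This is a minimal sketch under that reading of the question; the names masked, is_max and result are illustrative and not from the original post.

```python
import tensorflow as tf

tensor_1 = tf.constant([0, 1, 2, 3])
tensor_2 = tf.constant([0, 0, 0, 0])
mask = tf.constant([True, False, True, False])

# Replace masked-out values with the smallest representable value
# so they can never be selected as the maximum.
lowest = tf.constant(tensor_1.dtype.min, dtype=tensor_1.dtype)
masked = tf.where(mask, tensor_1, lowest)

# True only at masked positions that hold the masked maximum.
is_max = tf.logical_and(mask, tf.equal(masked, tf.reduce_max(masked)))

# Write 1 at those positions; keep tensor_2's value elsewhere.
result = tf.where(is_max, 1, tensor_2)
print(result.numpy())
```

With the inputs above, only index 2 is updated. If several masked positions tie for the maximum, all of them are set to 1.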

You can also do this on a 3D tensor without needing to flatten it:

xx = tf.random.uniform((2,2,3),maxval=10,dtype=tf.int32)

xx
>>>
<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
array([[[9, 8, 6],
        [7, 0, 2]],

       [[3, 8, 2],
        [2, 6, 7]]], dtype=int32)>


# Mark the position of the overall maximum with -5 and everything else with 0.
tf.where(xx == tf.math.reduce_max(xx), -5, 0)

>>>
<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
array([[[-5,  0,  0],
        [ 0,  0,  0]],

       [[ 0,  0,  0],
        [ 0,  0,  0]]], dtype=int32)>
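
For the batched case described in the question, shape (batch_size, 128, 128, 3) with a mask of the same shape, the same idea can be applied per image by reducing over the height, width and channel axes. The sketch below is one possible interpretation, not a definitive implementation: the helper name mark_masked_max_pixels is hypothetical, and it assumes that "the maximum" means the highest masked value in each image and that every channel of the pixel holding it should be set to 1. A small 2x2 image is used in place of 128x128 so the result is easy to inspect.

```python
import tensorflow as tf

def mark_masked_max_pixels(images, mask, target):
    """Hypothetical helper. images, mask, target have shape
    (batch, H, W, C); mask is boolean."""
    # Masked-out values get the dtype's minimum so they cannot win the max.
    lowest = tf.constant(images.dtype.min, dtype=images.dtype)
    masked = tf.where(mask, images, lowest)
    # One maximum per image, taken over height, width and channels.
    per_image_max = tf.reduce_max(masked, axis=[1, 2, 3], keepdims=True)
    is_max = tf.logical_and(mask, tf.equal(masked, per_image_max))
    # A pixel is selected if any of its channels holds the maximum.
    pixel_hit = tf.reduce_any(is_max, axis=-1, keepdims=True)
    # The (batch, H, W, 1) flag broadcasts across all channels.
    return tf.where(pixel_hit, tf.ones_like(target), target)

# Two tiny 2x2 images with 3 channels and distinct values 1..12.
img0 = tf.reshape(tf.range(1, 13), (2, 2, 3))
img1 = tf.reshape(tf.range(12, 0, -1), (2, 2, 3))
images = tf.stack([img0, img1])  # shape (2, 2, 2, 3)

# Image 0 hides its last pixel; image 1 is fully visible.
mask = tf.constant([
    [[[True] * 3, [True] * 3], [[True] * 3, [False] * 3]],
    [[[True] * 3, [True] * 3], [[True] * 3, [True] * 3]],
])

result = mark_masked_max_pixels(images, mask, tf.zeros_like(images))
print(result.numpy())
```

In image 0 the unmasked maximum is 9, at pixel (1, 0), and in image 1 it is 12, at pixel (0, 0), so those two pixels end up with 1 in all three channels. An image whose mask is entirely False is left unchanged, and tied maxima mark every tied pixel.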