My goal is to update a pixel in all 3 RGB channels of a tensor. For example,
t1 = tf.zeros([4, 4, 3], tf.int32)
t1 is a tensor with 4x4 (x, y) pixel positions and 3 channels. I use tf.where to find the index of the highest value in another tensor of the same shape, like
t2 = tf.random.uniform([4, 4, 3], maxval=10, dtype=tf.int32)
tf.where(t2 == tf.math.reduce_max(t2), 1, t1)
But this only changes the value at the single index that held the highest value in t2.
What I'm looking for is to update that (x, y) position in all 3 channels.
For instance, if t2 looked like
Channel 1 -> [[1, 2], [3, 4]]
Channel 2 -> [[5, 6], [7, 8]]
Channel 3 -> [[1, 0], [0, 0]]
The max value is 8, at index (1, 1) within channel 2.
I would then like t1 to look like
Channel 1 -> [[0, 0], [0, 1]]
Channel 2 -> [[0, 0], [0, 1]]
Channel 3 -> [[0, 0], [0, 1]]
How can I achieve this?
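Note that the example lists each channel as its own 2 x 2 matrix, while the tensor itself stores channels last (as in tf.zeros([4, 4, 3])). A minimal sketch, assuming that channels-last layout, of how the example's t2 and the desired t1 map onto [2, 2, 3] tensors; the names channel_1, channel_2, channel_3 and expected_t1 are illustrative only:

```python
import tensorflow as tf

# The three 2x2 channel slices from the example.
# Assumption: channels are the last axis, matching tf.zeros([4, 4, 3]).
channel_1 = tf.constant([[1, 2], [3, 4]])
channel_2 = tf.constant([[5, 6], [7, 8]])
channel_3 = tf.constant([[1, 0], [0, 0]])

# Stack along a new last axis -> shape (2, 2, 3), i.e. (x, y, channel).
t2 = tf.stack([channel_1, channel_2, channel_3], axis=-1)

# Desired result: 1 at pixel (1, 1) in every channel, 0 elsewhere.
expected_t1 = tf.stack([tf.constant([[0, 0], [0, 1]])] * 3, axis=-1)

print(t2.shape)             # (2, 2, 3)
print(t2[1, 1].numpy())     # all channels of pixel (1, 1): [4 8 0]
print(t2[..., 1].numpy())   # channel 2 recovered as a 2x2 slice
```

Indexing t2[1, 1] returns one value per channel, which is exactly the "pixel" the question wants to overwrite as a whole.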
Answer:
This will do it, I think. You need to reduce the max-value mask across the channel dimension (giving an H x W x 1 per-pixel mask) and then tile it back across the channel dimension.
import tensorflow as tf
# Sample tensor for demo purposes with integers 0-7; the max (7) is at t2[1, 1, 1]
# Channel is dimension 2
t2 = tf.reshape(tf.range(8), (2, 2, 2))
print(f"{t2=}")
mask3 = tf.cast(t2 == tf.reduce_max(t2), tf.int32)  # 3-D integer mask, 1 only where t2 holds its max
mask2 = tf.reduce_max(mask3, axis=2, keepdims=True)  # Reduce across channel dimension -> shape (2, 2, 1), 1 at the max pixel
t1 = tf.tile(mask2, (1, 1, t2.shape[2]))  # Tile to get 1 in the target pixel across all channels
print(f"\n{t1=}")
Answer:
First, we get the (partial) indices to update, i.e. the (x, y) coordinates with the channel coordinate dropped, as
idx = tf.where(t2 == tf.reduce_max(t2))[..., :-1]
or, if the max value can occur at multiple positions and you only want to use the first one,
idx = tf.where(t2 == tf.reduce_max(t2))[:1, ..., :-1]
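To see the difference between the two variants, here is a sketch with a hypothetical input where the max value 9 occurs at two different pixels. With a single boolean argument, tf.where returns the coordinates of the True entries in row-major order, so [:1] keeps the lexicographically smallest (x, y, channel) hit:

```python
import tensorflow as tf

# Hypothetical 2x2x2 input: the max (9) appears at pixel (0, 1) in channel
# index 0 and at pixel (1, 0) in channel index 1.
t2 = tf.constant([[[1, 2], [9, 3]],
                  [[4, 9], [5, 6]]])

hits = tf.where(t2 == tf.reduce_max(t2))  # int64, shape (num_hits, 3): (x, y, channel) rows
idx_all = hits[..., :-1]                  # drop the channel column -> one (x, y) per hit
idx_first = hits[:1, ..., :-1]            # keep only the first hit in row-major order

print(idx_all.numpy().tolist())    # [[0, 1], [1, 0]]
print(idx_first.numpy().tolist())  # [[0, 1]]
```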
Having the partial indices, we can update t1 using tf.tensor_scatter_nd_update.
# updates has shape (number of indices, number of channels): a row of ones per (x, y)
t1 = tf.tensor_scatter_nd_update(t1, idx, tf.ones((idx.shape[0], t1.shape[-1]), tf.int32))
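Putting these steps together on the question's 2 x 2 x 3 example gives a complete run. A sketch, assuming the channels-last layout built with tf.stack:

```python
import tensorflow as tf

# Example from the question, channels-last -> shape (2, 2, 3).
t2 = tf.stack([
    tf.constant([[1, 2], [3, 4]]),
    tf.constant([[5, 6], [7, 8]]),
    tf.constant([[1, 0], [0, 0]]),
], axis=-1)
t1 = tf.zeros_like(t2)  # int32 zeros, same shape as t2

# (x, y) of every position holding the max; the channel column is dropped.
idx = tf.where(t2 == tf.reduce_max(t2))[..., :-1]          # [[1, 1]]

# One row of updates per index, spanning the whole channel dimension.
updates = tf.ones((idx.shape[0], t1.shape[-1]), tf.int32)  # shape (1, 3)
t1 = tf.tensor_scatter_nd_update(t1, idx, updates)

print(t1[1, 1].numpy().tolist())  # [1, 1, 1]
```

Unlike building t1 from a mask, tf.tensor_scatter_nd_update only overwrites the indexed positions, so any other values already in t1 are kept.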
This works for any tensor with two or more dimensions, as long as the last dimension is the "channel" dimension.
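A sketch of that generality, wrapping the steps above in a hypothetical helper mark_max_pixel (not a TensorFlow API) and calling it on a rank-2 and a rank-4 tensor. It assumes eager execution, where idx.shape[0] is a concrete integer:

```python
import tensorflow as tf

def mark_max_pixel(t1, t2):
    """Set every channel of the position(s) holding t2's max to 1.

    Hypothetical helper wrapping the answer's steps. Assumes t1 and t2 share
    a shape of rank >= 2 with the channel dimension last, and eager mode.
    """
    idx = tf.where(t2 == tf.reduce_max(t2))[..., :-1]  # drop the channel coordinate
    updates = tf.ones((idx.shape[0], t1.shape[-1]), t1.dtype)
    return tf.tensor_scatter_nd_update(t1, idx, updates)

# Rank 2: 4 "pixels" x 3 channels; the max (9) is in row 2.
t2_2d = tf.constant([[1, 2, 3], [4, 5, 6], [0, 9, 0], [7, 8, 1]])
print(mark_max_pixel(tf.zeros_like(t2_2d), t2_2d).numpy().tolist())
# [[0, 0, 0], [0, 0, 0], [1, 1, 1], [0, 0, 0]]

# Rank 4: batch of 2 images, 2x2 pixels, 3 channels; max (23) at (1, 1, 1, 2).
t2_4d = tf.reshape(tf.range(24), (2, 2, 2, 3))
t1_4d = mark_max_pixel(tf.zeros_like(t2_4d), t2_4d)
print(t1_4d[1, 1, 1].numpy().tolist())  # [1, 1, 1]
```

In every case the indices have one column fewer than the tensor's rank, so each index selects a full row along the last (channel) axis.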