How to update a pixel value across 3 channels in tensorflow

Time:08-31

The goal I'm trying to achieve is to update a pixel in all 3 RGB channels of a Tensor. For example,

t1 = tf.zeros([4, 4, 3], tf.int32)

t1 is a tensor whose first two dimensions index the 4x4 pixel grid (x, y) and whose last dimension holds the 3 channels. I use tf.where to find the index of the highest value in another tensor of the same shape, like this:

t2 = tf.random.uniform([4, 4, 3], maxval=10, dtype=tf.int32)
tf.where(t2 == tf.math.reduce_max(t2), 1, t1)

But this only changes the value of the single index which had the highest value in t2. What I'm looking for is to update that (x,y) in all 3 channels.

For instance, if t2 looked like
Channel 1 -> [[1, 2], [3, 4]] Channel 2 -> [[5, 6], [7, 8]] Channel 3 -> [[1, 0], [0, 0]]
then the max value is 8, which sits at index (1,1) within its channel.

I would then like t1 to look like
Channel 1 -> [[0, 0], [0, 1]] Channel 2 -> [[0, 0], [0, 1]] Channel 3 -> [[0, 0], [0, 1]]

How can I achieve this?

CodePudding user response:

This will do it, I think. You need to reduce the mask across the channel dimension to get a per-pixel mask, and then tile it back across the channel dimension.

import tensorflow as tf

# Sample tensor for demo purposes with integers 0-7, the max will be in t2[1,1,1]
# Channel is dimension 2
t2 = tf.reshape(tf.range(8), (2, 2, 2))
print(f"{t2=}")

mask3 = tf.cast(t2 == tf.reduce_max(t2), tf.int32)     # 3-D integer mask, identifies the cell with max value
mask2 = tf.reduce_max(mask3, axis=2, keepdims=True)    # Reduce across channel dimension, find pixel with max value
t1 = tf.tile(mask2, (1, 1, t2.shape[2]))              # Tile to get 1 in target pixel across all channels
print(f"\n{t1=}")
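
If t1 already holds values that should be kept, the same per-pixel mask can be passed to tf.where instead of overwriting t1. The following is only a sketch, assuming TensorFlow 2.x, where tf.where broadcasts its arguments. The helper name mark_max_pixel is hypothetical, and the sample values are the 2x2 example from the question.

```python
import tensorflow as tf

def mark_max_pixel(t1, t2, value=1):
    """Hypothetical helper: write `value` into every channel of t1
    at each pixel where t2 reaches its global maximum."""
    # (H, W, C) boolean mask, True where t2 equals its global maximum
    mask3 = tf.equal(t2, tf.reduce_max(t2))
    # (H, W, 1) boolean mask, True if any channel of that pixel holds the maximum
    mask2 = tf.reduce_any(mask3, axis=-1, keepdims=True)
    # The (H, W, 1) mask broadcasts against the (H, W, C) tensor t1
    return tf.where(mask2, tf.cast(value, t1.dtype), t1)

# The three channels from the question, stacked along the last axis -> shape (2, 2, 3)
t2 = tf.stack([
    tf.constant([[1, 2], [3, 4]]),
    tf.constant([[5, 6], [7, 8]]),
    tf.constant([[1, 0], [0, 0]]),
], axis=-1)
t1 = tf.zeros_like(t2)

result = mark_max_pixel(t1, t2)
print(result[..., 0].numpy())  # channel 1 of the updated tensor
```

Because the mask keeps a trailing dimension of size 1, no explicit tf.tile is needed here; every other entry of t1 is passed through unchanged.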

CodePudding user response:

First, we get the (partial) indices to update as

idx = tf.where(t2 == tf.reduce_max(t2))[..., :-1]

or if there can be multiple positions with max value and you only want to use the first one, then

idx = tf.where(t2 == tf.reduce_max(t2))[:1, ..., :-1]

Having the partial indices, we can update t1 using tf.tensor_scatter_nd_update.

t1 = tf.tensor_scatter_nd_update(t1, idx, tf.ones((idx.shape[0],t1.shape[-1]), tf.int32))

This is applicable to any tensor with two or more dimensions, with the last dimension being the "channel" dimension.
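
Put together, the scatter approach might look like the sketch below. It assumes TensorFlow 2.x in eager mode (so that idx.shape[0] is a concrete number) and reuses the 2x2 example from the question; the variable name updates is introduced only for illustration.

```python
import tensorflow as tf

# The three channels from the question, stacked along the last axis -> shape (2, 2, 3)
t2 = tf.stack([
    tf.constant([[1, 2], [3, 4]]),
    tf.constant([[5, 6], [7, 8]]),
    tf.constant([[1, 0], [0, 0]]),
], axis=-1)
t1 = tf.zeros_like(t2)

# tf.where returns one row of full indices (x, y, channel) per maximum.
# Keep only the first match and drop the channel column -> shape (1, 2).
idx = tf.where(tf.equal(t2, tf.reduce_max(t2)))[:1, :-1]

# One row of ones per selected pixel, each row covering all channels
updates = tf.ones((idx.shape[0], t1.shape[-1]), dtype=t1.dtype)

# Each partial index (x, y) addresses a whole channel vector of t1
t1 = tf.tensor_scatter_nd_update(t1, idx, updates)
print(t1[1, 1].numpy())  # channel values at pixel (1, 1)
```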
