I am trying to perform some transformations on a batch of images of size [128,128,3] with a batch_size of 20. I have managed to do part of the transformations on a single image, but I'm not sure how to extend them to the entire batch.
`x` is the batched image tensor of shape [20,128,128,3].
`grads` is also a tensor of shape [20,128,128,3].
I want to find the maximum value of `grads` per image, i.e. I will get 20 max values. Let us call the max value for image_i `max_grads_i`.
`add_max_grads` is a tensor of shape [20,128,128,3].
I want to update `add_max_grads` such that, for image_i, the entries of `add_max_grads` at the position where `max_grads_i` occurs in `grads` are updated, across all 3 channels.
For instance, for image i (shape [128,128,3]), if `max_grads_i` was found at `grads[50][60][1]`, I want to update `add_max_grads` at [50][60][0], [50][60][1] and [50][60][2] with the same constant value.
Eventually, this needs to be extended to all the images in the batch to create `add_max_grads`, which has shape [20,128,128,3].
How can I achieve this?
Currently I am using `tf.math.reduce_max(grads, (1,2,3))` to find the max value in the `grads` tensor for each image in the batch, which returns a tensor of shape [20,], but I'm stuck on the rest.
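`tf.math.reduce_max` returns the maximum values, not where they are. One way to relate those values back to positions, sketched here with random `grads` as an illustration, is to keep the reduced axes with `keepdims=True`, so the result broadcasts against `grads` for an element-wise comparison. The names `max_grads_kd`, `is_max` and `hits_per_image` are hypothetical.

```python
import tensorflow as tf

grads = tf.random.normal(shape=(20, 128, 128, 3))

# One maximum per image: shape [20]
max_grads = tf.math.reduce_max(grads, axis=(1, 2, 3))
# keepdims=True keeps the reduced axes with size 1: shape [20, 1, 1, 1]
max_grads_kd = tf.math.reduce_max(grads, axis=(1, 2, 3), keepdims=True)

# The size-1 axes broadcast against grads, so every element is compared with
# the maximum of its own image: boolean mask of shape [20, 128, 128, 3]
is_max = tf.equal(grads, max_grads_kd)
# Number of elements equal to their image's maximum (1 per image unless there are ties)
hits_per_image = tf.reduce_sum(tf.cast(is_max, tf.int32), axis=(1, 2, 3))
```

`hits_per_image` is 1 for every image unless an image has tied maxima, which matters for any method that expects exactly one location per image.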
Answer:
I think this does what you're asking, if I've understood it right.
This is a self-contained example.
```python
import tensorflow as tf

# tshape = (20, 128, 128, 3)  # The real shape needed
tshape = (2, 4, 4, 3)  # A small shape, so the printed output stays readable

# Stand-in for add_max_grads, set to zeros so it is easy to see what gets updated
add_max_grads = tf.zeros(shape=tshape, dtype=tf.float32)
# grads has the same shape; use a random tensor here
grads = tf.random.normal(shape=tshape)
print(f"grads:\n{grads}")

# ---- Calculate the indices to use in tf.tensor_scatter_nd_update
# 4-D mask (m, h, w, c): 1 where grads equals the maximum of image m, 0 elsewhere
per_image_max = tf.reduce_max(grads, axis=(1, 2, 3), keepdims=True)
mask4d = tf.cast(tf.equal(grads, per_image_max), dtype=tf.int32)
# 3-D mask (m, h, w): 1 at the pixel holding the maximum of image m, in any channel
mask3d = tf.reduce_max(mask4d, axis=3)
# Indices of the maximum pixels; each row is [m, h, w].
# Tied maxima at different pixels of one image would give that image extra rows.
indices = tf.where(mask3d)
print(f"\nindices\n{indices}")

# ---- Calculate the updates for tf.tensor_scatter_nd_update: shape (number of index rows, c)
newval = 999
updates = tf.fill([tf.shape(indices)[0], tshape[3]], float(newval))

# ---- Scatter newval into add_max_grads. Each row of indices selects the slice add_max_grads[m, h, w, :]
add_max_grads = tf.tensor_scatter_nd_update(add_max_grads, indices, updates)
print(f"\nadd_max_grads\n{add_max_grads}")
```
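An alternative sketch, not part of the original answer, replaces the mask with `tf.argmax` over each flattened image. `tf.argmax` returns exactly one position per image (if there are ties, which one it returns is not guaranteed), so the scatter always receives exactly `batch` index rows and `updates` has the fixed shape `[batch, channels]`. The names `flat_pos`, `coords` and `image_ids` are illustrative.

```python
import tensorflow as tf

tshape = (2, 4, 4, 3)  # small shape for readable output; (20, 128, 128, 3) works the same way
batch, height, width, channels = tshape
newval = 999.0

add_max_grads = tf.zeros(shape=tshape, dtype=tf.float32)
grads = tf.random.normal(shape=tshape)

# Flat position of each image's maximum: shape [batch], dtype int64
flat_pos = tf.argmax(tf.reshape(grads, [batch, -1]), axis=1)
# Convert to (row, col, channel) coordinates: shape [3, batch]
coords = tf.unravel_index(flat_pos, tf.constant([height, width, channels], dtype=tf.int64))

# One index row [m, row, col] per image; leaving out the channel makes each
# row address the whole slice add_max_grads[m, row, col, :]
image_ids = tf.range(batch, dtype=tf.int64)
indices = tf.stack([image_ids, coords[0], coords[1]], axis=1)  # shape [batch, 3]

# One row of `channels` copies of newval per index row: shape [batch, channels]
updates = tf.fill([batch, channels], newval)
add_max_grads = tf.tensor_scatter_nd_update(add_max_grads, indices, updates)
print(f"add_max_grads\n{add_max_grads}")
```

The practical difference is tie handling: with the mask version, tied maxima at different pixels produce more than one index row for an image, while with `tf.argmax` each image contributes exactly one.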