Perform image transformations on batches of images in TensorFlow

Time:09-01

I am trying to apply some transformations to a batch of 20 images, each of shape [128,128,3]. I have managed to do part of this for a single image, but I'm not sure how to extend it to the entire batch.

x is the batched image tensor of shape [20,128,128,3]

grads is also a tensor of shape [20,128,128,3]

I want to find the maximum value of grads per image, i.e. I will get 20 max values. Let us call the max value for image_i max_grads_i.

add_max_grads is a tensor of shape [20,128,128,3]

I want to update add_max_grads so that, for each image_i, the pixel where max_grads_i occurs in the grads tensor is set to a constant value across all 3 channels.

For instance, for image_i of shape [128,128,3], if max_grads_i was found at grads[50][60][1], I want to update add_max_grads at [50][60][0], [50][60][1] and [50][60][2] with the same constant value.

Eventually, this needs to be extended to all the images in the batch, producing an add_max_grads of shape [20,128,128,3].

How can I achieve this?

Currently I am using tf.math.reduce_max(grads,(1,2,3)) to find the max value in the grads tensor for each image in the batch, which returns a tensor of shape [20], but I'm stuck on the rest.
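
A minimal sketch of this step, assuming TensorFlow 2.x with eager execution and a statically known shape: besides the per-image maximum, the position of that maximum can be recovered by flattening each image and calling tf.argmax. The function name per_image_max_and_position is hypothetical and only for illustration. If an image has tied maxima, tf.argmax reports just one of them.

```python
import tensorflow as tf

def per_image_max_and_position(grads):
    """Return per-image max values and the (row, col, channel) of each max.

    Hypothetical helper; assumes grads has a static shape (batch, h, w, c).
    """
    batch, height, width, channels = grads.shape
    # One max value per image: shape (batch,)
    max_vals = tf.math.reduce_max(grads, axis=(1, 2, 3))
    # Flatten each image so argmax yields a single index per image
    flat = tf.reshape(grads, (batch, -1))
    flat_idx = tf.argmax(flat, axis=1, output_type=tf.int32)
    # Convert the flat index back to (row, col, channel) for a row-major layout
    row = flat_idx // (width * channels)
    col = (flat_idx // channels) % width
    ch = flat_idx % channels
    positions = tf.stack([row, col, ch], axis=1)  # shape (batch, 3)
    return max_vals, positions
```

For the real tensors in the question, max_vals would have shape [20] and positions shape [20,3].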

CodePudding user response:

I think this does what you're asking, if I've understood it right.

This is a self-contained example.

import tensorflow as tf
#tshape = (20,128,128,3)     # The real shape needed
tshape = (2,4,4,3)           # A small shape, to allow print output to be readable

# A sample x tensor, set to zeros here so we can see what gets updated easily
x = tf.zeros(shape=tshape, dtype=tf.float32)

# grads, same shape as x, let's use a random tensor here
grads = tf.random.normal(shape=tshape)
print(f"grads: \n{grads}")

# ---- Now calculate the indices we will use in tf.tensor_scatter_nd_update
# 4-D mask tensor (m, x, y, c) with 1 in max locations (by m), zeros otherwise
mask4d = tf.cast(grads == tf.reduce_max(grads, axis=(1,2,3), keepdims=True), dtype=tf.int32)
# 3D mask tensor (m, x, y) with 1 in max pixel locations (by m) across channel
mask3d = tf.reduce_max(mask4d, axis=3)
# indices of maximum values, each item gives m, x, y
indices = tf.where(mask3d)
print(f"\nindices\n{indices}")

# ---- Now calculate the updates we will use in tf.tensor_scatter_nd_update. This has shape (n, c),
# where n is the number of rows in indices (equal to m unless an image has its max at several pixels)
newval = 999.0
updates = tf.fill([tf.shape(indices)[0], tshape[3]], newval)

# ---- Now modify x by scattering newval through it. Each entry in indices indicates a slice x[m, x, y, :]
x_updated = tf.tensor_scatter_nd_update(x, indices, updates)
print(f"\nx_updated\n{x_updated}")
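
As an alternative to the scatter approach above, here is a hedged sketch that builds the result with a broadcast boolean mask instead of explicit indices. It assumes TensorFlow 2.x, where the three-argument tf.where broadcasts its inputs. The function name mark_max_pixels and its parameters are hypothetical. If an image has its maximum at several pixels, all of them are marked.

```python
import tensorflow as tf

def mark_max_pixels(grads, add_max_grads, new_value):
    """Set every channel of each image's max-gradient pixel to new_value.

    Hypothetical helper; grads and add_max_grads share shape (batch, h, w, c).
    """
    # Per-image maximum kept as (batch, 1, 1, 1) so it broadcasts against grads
    per_image_max = tf.reduce_max(grads, axis=(1, 2, 3), keepdims=True)
    # True where a pixel holds its image's max in any channel: (batch, h, w, 1)
    pixel_mask = tf.reduce_any(
        tf.equal(grads, per_image_max), axis=3, keepdims=True)
    # Tensor of the constant value, same shape and dtype as add_max_grads
    fill = tf.ones_like(add_max_grads) * new_value
    # The (batch, h, w, 1) mask broadcasts over the channel axis
    return tf.where(pixel_mask, fill, add_max_grads)

# Usage with the small shape from the example above
grads = tf.random.normal(shape=(2, 4, 4, 3))
add_max_grads = tf.zeros_like(grads)
print(mark_max_pixels(grads, add_max_grads, 999.0))
```

This version never materialises an index list, so the number of updates cannot get out of step with the number of maxima found.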
