updating an entire 3x3 patch based on a center index, (efficiently) - tensorflow

Time:02-19

Say I have the following tensor:

tf.Tensor(
[[[[ 1.1008097   1.4609661  -0.07704023]
   [ 0.51914555 -0.44149    -0.49901748]
   [-1.5465652  -1.066219    0.72451854]]

  [[-0.48539782 -1.6190584  -0.69578767]
   [-1.3168293   0.7194597   0.23308933]
   [ 0.95422655 -0.8992636  -0.08333881]]]], shape=(1, 2, 3, 3), dtype=float32)

As I understand it, this is two 3x3 matrices.
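To make the layout concrete, here is a quick check, assuming TensorFlow 2.x with eager execution and reading the shape as (batch, channel, row, column), which is an interpretation of the axes rather than something the question states. With that reading, x[0, 0] is the first 3x3 matrix, x[0, 1] is the second, and x[0, 0, 1, 1] is the centre element of the first one.

```python
import tensorflow as tf

# The tensor from the question, shape (1, 2, 3, 3)
x = tf.constant([[[[ 1.1008097,   1.4609661,  -0.07704023],
                   [ 0.51914555, -0.44149,    -0.49901748],
                   [-1.5465652,  -1.066219,    0.72451854]],
                  [[-0.48539782, -1.6190584,  -0.69578767],
                   [-1.3168293,   0.7194597,   0.23308933],
                   [ 0.95422655, -0.8992636,  -0.08333881]]]])

first_matrix = x[0, 0]    # first 3x3 matrix
second_matrix = x[0, 1]   # second 3x3 matrix
centre = x[0, 0, 1, 1]    # the -0.44149 element, i.e. index [0, 0, 1, 1]
print(first_matrix.shape, second_matrix.shape, float(centre))
```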

I want to set the entire 3x3 patch centered on index [0, 0, 1, 1] (the -0.44149 element) to 0.0.

I'm looking for an efficient way to do this, something like:

tf.tensor_scatter_nd_update(inputs, [index[0], index[1], index[2]-1 : index[2]+2, index[3]-1 : index[3]+2], tf.constant(0.0))  # pseudocode: slices are not valid indices here

where index is the list [0, 0, 1, 1].
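tf.tensor_scatter_nd_update does not accept slices: it expects an explicit list of full index tuples, one per element to overwrite, plus one update value per tuple. A minimal sketch of that idea simply enumerates the nine [batch, channel, row, column] tuples of the patch in plain Python. It assumes the centre is at least one element away from every border, and it uses a random tensor of the question's shape as a stand-in for inputs.

```python
import tensorflow as tf

x = tf.random.normal((1, 2, 3, 3))   # stand-in for `inputs`
index = [0, 0, 1, 1]
b, c, r, col = index

# All nine full indices of the 3x3 patch centred on (r, col)
patch_indices = [[b, c, r + dr, col + dc]
                 for dr in (-1, 0, 1)
                 for dc in (-1, 0, 1)]
# One 0.0 update per index
updates = tf.zeros([len(patch_indices)], dtype=x.dtype)

y = tf.tensor_scatter_nd_update(x, patch_indices, updates)
print(y)
```

Building the index list in Python is fine for a single small patch; for larger tensors the index construction can be done with tensor ops instead, as in the answer below.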

So the desired result would be:

tf.Tensor(
[[[[ 0.0   0.0    0.0]
   [ 0.0   0.0    0.0]
   [ 0.0   0.0    0.0]]

  [[-0.48539782 -1.6190584  -0.69578767]
   [-1.3168293   0.7194597   0.23308933]
   [ 0.95422655 -0.8992636  -0.08333881]]]], shape=(1, 2, 3, 3), dtype=float32)

Eventually I want to do this on a bigger tensor, for example one with shape (1, 16, 30, 30).
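For a larger tensor such as (1, 16, 30, 30), one alternative to scattering is to build a boolean mask and select with tf.where. This is a different technique from the scatter-based answer: the mask is computed by comparing row and column ranges against the patch bounds, so it is fully vectorised, and a centre on the border simply yields a smaller, clipped patch instead of out-of-range indices. The helper name zero_patch_mask and its radius parameter are hypothetical, chosen for this sketch.

```python
import tensorflow as tf

def zero_patch_mask(x, index, radius=1):
    """Hypothetical helper: zero the (2*radius+1) x (2*radius+1) patch centred
    on index = [b, c, r, col] of a 4-D tensor x, using a boolean mask."""
    b, c, r, col = index
    shape = tf.shape(x)
    # 1-D masks along each axis, True only on the selected coordinates
    in_b = tf.equal(tf.range(shape[0]), b)
    in_c = tf.equal(tf.range(shape[1]), c)
    in_r = tf.abs(tf.range(shape[2]) - r) <= radius
    in_col = tf.abs(tf.range(shape[3]) - col) <= radius
    # Broadcast the four 1-D masks to the full 4-D shape
    mask = (in_b[:, None, None, None] & in_c[None, :, None, None]
            & in_r[None, None, :, None] & in_col[None, None, None, :])
    # Zeros inside the patch, original values everywhere else
    return tf.where(mask, tf.zeros_like(x), x)

x = tf.random.normal((1, 16, 30, 30))
y = zero_patch_mask(x, [0, 3, 10, 20])
# Count how many elements are now exactly zero
print(int(tf.reduce_sum(tf.cast(tf.equal(y, 0.0), tf.int32))))
```

The trade-off is that tf.where touches every element of the tensor, whereas a scatter only writes the patch; for a (1, 16, 30, 30) tensor either is cheap, so the choice mostly comes down to readability and border handling.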

Answer:

If I understand correctly, you can try something like this:

import tensorflow as tf

x = tf.constant([[[[ 1.1008097,   1.4609661,  -0.07704023],
   [ 0.51914555, -0.44149,    -0.49901748],
   [-1.5465652,  -1.066219,    0.72451854]],

  [[-0.48539782, -1.6190584 , -0.69578767],
   [-1.3168293,   0.7194597,   0.23308933],
   [ 0.95422655, -0.8992636,  -0.08333881]]]])

index = [0, 0, 1, 1]

# (row, col) coordinates of the 3x3 patch centered on (index[2], index[3]), shape (3, 3, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(index[2] - 1, index[2] + 2, dtype=tf.int32),
    tf.range(index[3] - 1, index[3] + 2, dtype=tf.int32),
                              indexing='ij'), axis=-1)
first_dims = tf.constant([index[0], index[1]])
ij_shape = tf.shape(ij)
# tile (not repeat) so that every patch position gets the pair [index[0], index[1]]
indices = tf.reshape(tf.tile(first_dims, [ij_shape[0]*ij_shape[1]]), (ij_shape[0], ij_shape[1], tf.shape(first_dims)[0]))
# full 4-D indices [b, c, row, col], flattened to shape (9, 4)
indices = tf.concat([indices, ij], axis=-1)
indices_shape = tf.shape(indices)
indices = tf.reshape(indices, (indices_shape[0]*indices_shape[1], indices_shape[2]))
# one 0.0 update per index
updates = tf.repeat(tf.constant(0.0), repeats=ij_shape[0]*ij_shape[1])

print(tf.tensor_scatter_nd_update(x, indices, updates))
tf.Tensor(
[[[[ 0.          0.          0.        ]
   [ 0.          0.          0.        ]
   [ 0.          0.          0.        ]]

  [[-0.48539782 -1.6190584  -0.69578767]
   [-1.3168293   0.7194597   0.23308933]
   [ 0.95422655 -0.8992636  -0.08333881]]]], shape=(1, 2, 3, 3), dtype=float32)
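The meshgrid-plus-scatter idea above can be wrapped in a reusable function for the (1, 16, 30, 30) case from the question. This is a sketch under stated assumptions, not the answerer's code: zero_patch_scatter and radius are hypothetical names, and the row and column ranges are clipped to the tensor bounds because tf.tensor_scatter_nd_update raises an error on CPU when an index is out of range.

```python
import tensorflow as tf

def zero_patch_scatter(x, index, radius=1):
    """Hypothetical helper: zero the patch around index = [b, c, r, col]
    with tf.tensor_scatter_nd_update, clipping the patch at the borders."""
    b, c, r, col = index
    height, width = x.shape[2], x.shape[3]
    # Clip the patch bounds so no index falls outside the tensor
    rows = tf.range(max(r - radius, 0), min(r + radius + 1, height))
    cols = tf.range(max(col - radius, 0), min(col + radius + 1, width))
    rr, cc = tf.meshgrid(rows, cols, indexing='ij')
    n = tf.size(rr)
    # One [b, c, row, col] index per position in the (possibly clipped) patch
    indices = tf.stack([tf.fill([n], b), tf.fill([n], c),
                        tf.reshape(rr, [-1]), tf.reshape(cc, [-1])], axis=1)
    updates = tf.zeros([n], dtype=x.dtype)
    return tf.tensor_scatter_nd_update(x, indices, updates)

x = tf.random.normal((1, 16, 30, 30))
y = zero_patch_scatter(x, [0, 5, 15, 15])      # full 3x3 patch
corner = zero_patch_scatter(x, [0, 5, 0, 29])  # clipped to a 2x2 patch
```

Clipping in Python with max and min assumes the spatial dimensions are known statically, which holds in eager mode for a concrete tensor like the one in the question.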