Say I have the following tensor:
tf.Tensor(
[[[[ 1.1008097 1.4609661 -0.07704023]
[ 0.51914555 -0.44149 -0.49901748]
[-1.5465652 -1.066219 0.72451854]]
[[-0.48539782 -1.6190584 -0.69578767]
[-1.3168293 0.7194597 0.23308933]
[ 0.95422655 -0.8992636 -0.08333881]]]], shape=(1, 2, 3, 3), dtype=float32)
As I understand it, this is two 3x3 matrices.
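To make the shape concrete, (1, 2, 3, 3) can be read as one batch holding two 3x3 matrices, so indexing the first two axes picks out a single matrix. Here is a minimal sketch that rebuilds the tensor above from its printed values. The names `first`, `second` and `center` are only for illustration.

```python
import tensorflow as tf

# Same values as the tensor above, shape (1, 2, 3, 3)
x = tf.constant([[[[ 1.1008097, 1.4609661, -0.07704023],
                   [ 0.51914555, -0.44149, -0.49901748],
                   [-1.5465652, -1.066219, 0.72451854]],
                  [[-0.48539782, -1.6190584, -0.69578767],
                   [-1.3168293, 0.7194597, 0.23308933],
                   [ 0.95422655, -0.8992636, -0.08333881]]]])

first = x[0, 0]          # first 3x3 matrix
second = x[0, 1]         # second 3x3 matrix
center = x[0, 0, 1, 1]   # the scalar at index [0, 0, 1, 1]
print(first.shape, second.shape, float(center))
```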
I want to set the entire 3x3 patch centered on index [0, 0, 1, 1] (the -0.44149 element) to 0.0.
I'm looking for an efficient way to do this. The following pseudocode shows the idea, although tf.tensor_scatter_nd_update does not actually accept slices:
tf.tensor_scatter_nd_update(inputs, ([index[0], index[1], index[2]-1 : index[2]+1, index[3]-1 : index[3]+1]), tf.constant(0.0))
Here index is the list [0, 0, 1, 1].
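Plain `tf.Tensor` objects are immutable, so slice assignment like the pseudocode above does not work on them. A `tf.Variable` does support NumPy-style sliced assignment through `.assign()`. This is a minimal sketch that assumes copying the input into a variable is acceptable. The random `x` is only a stand-in for the real input.

```python
import tensorflow as tf

# Stand-in for the real input, same shape as in the question
x = tf.random.normal((1, 2, 3, 3))
index = [0, 0, 1, 1]

# Copy into a Variable, which supports sliced assignment
v = tf.Variable(x)
r, c = index[2], index[3]
# Python slices are end-exclusive, so the 3x3 neighbourhood is r-1 : r+2
v[index[0], index[1], r - 1:r + 2, c - 1:c + 2].assign(tf.zeros((3, 3)))
result = tf.convert_to_tensor(v)
print(result)
```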
So the desired result would be:
tf.Tensor(
[[[[ 0.0 0.0 0.0]
[ 0.0 0.0 0.0]
[ 0.0 0.0 0.0]]
[[-0.48539782 -1.6190584 -0.69578767]
[-1.3168293 0.7194597 0.23308933]
[ 0.95422655 -0.8992636 -0.08333881]]]], shape=(1, 2, 3, 3), dtype=float32)
Eventually I want to do this on a bigger tensor, for example one with shape (1, 16, 30, 30).
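To scale this up without building scatter indices, one option is to construct a boolean mask for the patch and select with `tf.where`. This is a minimal sketch under a few assumptions:

- The layout is [batch, channel, row, col].
- The patch is 3x3, which corresponds to `half=1`.

`zero_patch_mask` is a hypothetical helper name. Patch cells that would fall outside the tensor are simply ignored.

```python
import tensorflow as tf


def zero_patch_mask(x, index, half=1):
    """Return a copy of x with the (2*half+1) x (2*half+1) patch centred at
    index = [batch, channel, row, col] set to 0. Hypothetical helper."""
    b, c, r, col = index
    shape = tf.shape(x)
    # Row/column positions of the last two axes, shaped for broadcasting
    rows = tf.range(shape[2])[:, None]  # (H, 1)
    cols = tf.range(shape[3])[None, :]  # (1, W)
    in_patch = tf.logical_and(tf.abs(rows - r) <= half,
                              tf.abs(cols - col) <= half)  # (H, W)
    # True only for the selected (batch, channel) pair
    in_bc = tf.logical_and(tf.equal(tf.range(shape[0])[:, None], b),
                           tf.equal(tf.range(shape[1])[None, :], c))  # (B, C)
    mask = tf.logical_and(in_bc[:, :, None, None],
                          in_patch[None, None, :, :])  # (B, C, H, W)
    return tf.where(mask, tf.zeros_like(x), x)


x = tf.random.normal((1, 16, 30, 30))
y = zero_patch_mask(x, [0, 3, 10, 20])
```

The mask approach touches every element once, which is simple and vectorized. A scatter-based approach only writes the k*k patch cells.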
Answer:
IIUC, you can try something like this:
import tensorflow as tf

x = tf.constant([[[[ 1.1008097, 1.4609661, -0.07704023],
                   [ 0.51914555, -0.44149, -0.49901748],
                   [-1.5465652, -1.066219, 0.72451854]],
                  [[-0.48539782, -1.6190584, -0.69578767],
                   [-1.3168293, 0.7194597, 0.23308933],
                   [ 0.95422655, -0.8992636, -0.08333881]]]])
index = [0, 0, 1, 1]

# All (row, col) pairs of the last two dimensions, shape (rows, cols, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(tf.shape(x)[2], dtype=tf.int32),
    tf.range(tf.shape(x)[3], dtype=tf.int32),
    indexing='ij'), axis=-1)

# Prefix every (row, col) pair with the fixed [batch, channel] indices.
# tf.tile keeps the values interleaved as [b, c, b, c, ...];
# tf.repeat would give [b, b, ..., c, c, ...], which is only correct when b == c.
first_dims = tf.constant([index[0], index[1]])
ij_shape = tf.shape(ij)
indices = tf.reshape(tf.tile(first_dims, [ij_shape[0]*ij_shape[1]]),
                     (ij_shape[0], ij_shape[1], tf.shape(first_dims)[0]))
indices = tf.concat([indices, ij], axis=-1)

# Flatten to a (rows*cols, 4) list of full 4-D indices
indices_shape = tf.shape(indices)
indices = tf.reshape(indices, (indices_shape[0]*indices_shape[1], indices_shape[2]))

# One 0.0 update per index
updates = tf.repeat(tf.constant(0.0), repeats=ij_shape[0]*ij_shape[1])
print(tf.tensor_scatter_nd_update(x, indices, updates))
tf.Tensor(
[[[[ 0. 0. 0. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]]
[[-0.48539782 -1.6190584 -0.69578767]
[-1.3168293 0.7194597 0.23308933]
[ 0.95422655 -0.8992636 -0.08333881]]]], shape=(1, 2, 3, 3), dtype=float32)
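The code above zeroes the whole slice x[index[0], index[1]]. That matches the 3x3 patch only when each matrix is itself 3x3. For the (1, 16, 30, 30) case, the following sketch generates scatter indices only for the neighbourhood around the centre. `zero_patch_scatter` and its `half` parameter are hypothetical names. The sketch assumes the whole patch lies inside the tensor, because tf.tensor_scatter_nd_update raises an error on CPU for out-of-range indices.

```python
import tensorflow as tf


def zero_patch_scatter(x, index, half=1):
    """Return a copy of x with the (2*half+1) x (2*half+1) patch centred at
    index = [batch, channel, row, col] set to 0. Hypothetical helper; assumes
    the whole patch lies inside x."""
    b, c, r, col = index
    offsets = tf.range(-half, half + 1)
    # All (row, col) positions of the patch, flattened to shape (k*k, 2)
    rr, cc = tf.meshgrid(r + offsets, col + offsets, indexing='ij')
    ij = tf.reshape(tf.stack([rr, cc], axis=-1), (-1, 2))
    n = tf.shape(ij)[0]
    # Prefix each position with the same [batch, channel] -> shape (k*k, 4)
    bc = tf.tile(tf.constant([[b, c]], dtype=ij.dtype), [n, 1])
    indices = tf.concat([bc, ij], axis=1)
    updates = tf.zeros([n], dtype=x.dtype)
    return tf.tensor_scatter_nd_update(x, indices, updates)


x = tf.random.normal((1, 16, 30, 30))
y = zero_patch_scatter(x, [0, 3, 10, 20])
```

With index [0, 0, 1, 1] on the 3x3 example, this zeroes the same nine elements as the answer's code.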