I was just wondering whether there is a better way of updating a tensor in TF2. Let's say that I have tensor_a = tf.ones([4, 5, 5]) with shape (batch_size, H, W), and I would like to replace all the values of the second sample (index=1) with zeros. This is how I manage to do it without using eager execution mode:
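import tensorflow as tf
tensor_a = tf.ones([4, 5, 5])   # (batch_size, H, W)
tensor_b = tf.zeros([1, 5, 5])  # replacement for one sample
index = 1
# Rebuild the tensor: samples before index, the new sample, samples after index.
tensor_a = tf.concat([tensor_a[:index], tensor_b, tensor_a[index + 1:]], axis=0)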
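The concat trick above can be wrapped in a small reusable function so it reads like a slice assignment. This is a minimal sketch: `replace_sample` is a hypothetical helper name (not a TensorFlow API), and it assumes `new_sample` already has shape (1, H, W). Wrapping it in tf.function shows it also runs as a graph, not only eagerly.

```python
import tensorflow as tf

def replace_sample(tensor, index, new_sample):
    """Hypothetical helper: return a copy of `tensor` whose row `index`
    along axis 0 is replaced by `new_sample` (shape (1,) + tensor.shape[1:]).
    It only wraps the tf.concat pattern from the question."""
    return tf.concat([tensor[:index], new_sample, tensor[index + 1:]], axis=0)

# Trace the helper into a graph to show it does not depend on eager mode.
replace_sample_graph = tf.function(replace_sample)

tensor_a = replace_sample_graph(tf.ones([4, 5, 5]), 1, tf.zeros([1, 5, 5]))
print(tf.reduce_sum(tensor_a, axis=[1, 2]).numpy())  # per-sample sums: 25, 0, 25, 25
```

The original tensor is never modified; TensorFlow tensors are immutable, so every approach returns a new tensor.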
I know that the tf.tensor_scatter_nd_update() function exists, but I'm not familiar with meshgrids, and in my opinion they look a bit ugly for a simple slice-assignment operation. Also, in some cases it would be handy to update several slices at once (for example, setting samples 0, 1 and 2 to zeros).
Answer:
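TensorFlow operations are a bit messy sometimes, but no meshgrid is needed here: tf.tensor_scatter_nd_update can address whole samples through the first (batch) dimension alone.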
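import tensorflow as tf
tensor = tf.ones([4, 5, 5])
# indices [[1]] names position 1 on axis 0, i.e. the whole (5, 5) slice of sample 1.
# The updates must then have shape (1, 5, 5); zeros_like(gather(...)) builds exactly that.
tensor = tf.tensor_scatter_nd_update(
    tensor, [[1]], tf.zeros_like(tf.gather(tensor, [1]))
)
print(repr(tensor))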
<tf.Tensor: shape=(4, 5, 5), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],

       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]], dtype=float32)>
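The same call also covers the multi-index case from the question (samples 0, 1 and 2 at once). A minimal sketch, assuming the rows to overwrite are known as a Python list `rows` (an illustrative name): the indices become a column of shape (3, 1) and the updates a stack of three (5, 5) slices. For comparison it also shows a different technique, a boolean mask over axis 0 combined with tf.where, which broadcasts the mask against (H, W).

```python
import tensorflow as tf

base = tf.ones([4, 5, 5])  # (batch_size, H, W)
rows = [0, 1, 2]           # illustrative: samples to set to zero

# Option 1: tf.tensor_scatter_nd_update.
# indices has shape (3, 1): each row names one position on axis 0,
# so it selects a whole (5, 5) slice and no meshgrid is needed.
indices = tf.reshape(tf.constant(rows), [-1, 1])
# updates must have shape indices.shape[:-1] + base.shape[1:] = (3, 5, 5).
updates = tf.zeros([len(rows), 5, 5])
scattered = tf.tensor_scatter_nd_update(base, indices, updates)

# Option 2 (a different technique): boolean mask over axis 0.
batch_ids = tf.range(tf.shape(base)[0])                # [0, 1, 2, 3]
hits = tf.equal(batch_ids[:, None], tf.constant(rows))  # (4, 3) via broadcasting
mask = tf.reduce_any(hits, axis=1)                      # (4,): True for rows 0, 1, 2
# Reshape the mask to (4, 1, 1) so tf.where broadcasts it over H and W.
masked = tf.where(mask[:, None, None], tf.zeros_like(base), base)

print(tf.reduce_sum(scattered, axis=[1, 2]).numpy())  # per-sample sums: 0, 0, 0, 25
```

The scatter version is explicit about which rows change and lets each row receive different values; the mask version is handy when a boolean condition per sample is already available.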