What is a more sophisticated way to assign to a tensor in TensorFlow 2.x?

Time:11-15

I was wondering whether there is a better way of updating a tensor in TF2. Say I have tensor_a = tf.ones([4, 5, 5]) with shape (batch_size, H, W), and I would like to replace all the values of the second sample (index=1) with zeros. This is how I manage to do it without using eager execution mode:

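import tensorflow as tf

tensor_a = tf.ones([4, 5, 5])   # (batch_size, H, W)
tensor_b = tf.zeros([1, 5, 5])  # replacement for one sample
index = 1
# Rebuild the tensor: samples before index, the replacement, samples after index.
tensor_a = tf.concat([tensor_a[:index], tensor_b, tensor_a[index + 1:]], axis=0)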
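If holding the data in a mutable tf.Variable instead of an immutable tf.Tensor is acceptable (an assumption, not something the question requires), sliced assignment with .assign() avoids rebuilding the tensor by hand. A minimal sketch; the name var is illustrative:

```python
import tensorflow as tf

# Assumption: the data may live in a mutable tf.Variable rather than a tf.Tensor.
var = tf.Variable(tf.ones([4, 5, 5]))  # (batch_size, H, W)

# Sliced assignment on a Variable: overwrite sample 1 with a (5, 5) block of zeros.
var[1].assign(tf.zeros([5, 5]))

# A contiguous range of samples (here 2 and 3) can be assigned the same way.
var[2:4].assign(tf.zeros([2, 5, 5]))

result = var.numpy()
```

This only covers slices expressible with start:stop notation; arbitrary, non-contiguous sample indices still call for a scatter-style update.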

I know that the tf.tensor_scatter_nd_update() function exists, but I'm not familiar with meshgrids, and in my opinion they look a little ugly for a simple slice assignment. In some cases it would also be handy to update slices at many indices at once (for example, setting samples 0, 1 and 2 to zeros).
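When only the batch (first) axis is indexed, tf.tensor_scatter_nd_update() needs no meshgrid: each index is a one-element list, so every index addresses a whole (H, W) slice, and several samples can be updated in one call. A minimal sketch, where sample_ids is an illustrative name:

```python
import tensorflow as tf

tensor_a = tf.ones([4, 5, 5])  # (batch_size, H, W)

# Samples to overwrite. Wrapping each index in its own list gives indices
# of shape (num_updates, 1), so each one selects a full (H, W) slice.
sample_ids = [0, 1, 2]
indices = tf.constant([[i] for i in sample_ids])

# One (H, W) update per index: shape (num_updates, H, W).
updates = tf.zeros([len(sample_ids), 5, 5])

tensor_a = tf.tensor_scatter_nd_update(tensor_a, indices, updates)
```

The rule to remember is that updates must have shape indices.shape[:-1] + tensor.shape[indices.shape[-1]:], here (3,) + (5, 5).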

Answer:

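TensorFlow operations are a bit messy sometimes, but tf.tensor_scatter_nd_update() does this without a meshgrid when you only index along the batch axis: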

import tensorflow as tf

tensor = tf.ones([4, 5, 5])

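# [[1]] selects the whole sample at batch index 1; the update must have
# shape (1, 5, 5), which tf.gather(tensor, [1]) provides via zeros_like.
tensor = tf.tensor_scatter_nd_update(
    tensor, [[1]], tf.zeros_like(tf.gather(tensor, [1]))
)
tensor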
<tf.Tensor: shape=(4, 5, 5), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],
       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],
       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]],
       [[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]], dtype=float32)>
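A different technique (not a scatter) that some find easier to read is a boolean mask over the batch axis combined with tf.where(). A minimal sketch, assuming the samples to zero out are given as a 1-D integer tensor; sample_ids and batch_ids are illustrative names:

```python
import tensorflow as tf

tensor = tf.ones([4, 5, 5])        # (batch_size, H, W)
sample_ids = tf.constant([0, 2])   # illustrative choice of samples to zero out

# Boolean mask over the batch axis: True where the batch index is in sample_ids.
batch_ids = tf.range(tf.shape(tensor)[0])
mask = tf.reduce_any(tf.equal(batch_ids[:, None], sample_ids[None, :]), axis=1)  # (4,)

# Reshape the mask to (4, 1, 1) so it broadcasts over H and W,
# then take zeros where it is True and the original values elsewhere.
tensor = tf.where(mask[:, None, None], tf.zeros_like(tensor), tensor)
```

This computes both branches for every element, so for very large tensors with few updated samples the scatter version is likely cheaper.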