1st: mask (11 True entries)
2nd: array (11 elements)
3rd: zero_tensor (same shape as the mask tensor): I want each position that is True in the mask to be assigned the corresponding element of the array.
In torch, we can use zero_tensor[mask] = array, but how can we do this in TensorFlow?
Answer:
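Maybe something like this: use tf.where(mask) to get the coordinates of the True entries, then tf.tensor_scatter_nd_update to write the values at those coordinates. tf.where returns the coordinates in row-major order, so the i-th element of the array goes to the i-th True position, which is what zero_tensor[mask] = array does in torch. The array must have exactly as many elements as the mask has True entries.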
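import tensorflow as tf

x = tf.zeros((5, 3), dtype=tf.int32)  # target tensor, same shape as mask
new_values = tf.constant([1, 1, 1, 2, 2, 2, 3, 4, 4, 5, 5], dtype=tf.int32)  # 11 values
mask = tf.constant([[True, True, True],
                    [True, True, True],
                    [True, False, False],
                    [True, True, False],
                    [True, True, False]])  # 11 True entries
print(x)
# tf.where(mask) gives the [row, col] coordinates of the True entries in row-major order;
# tensor_scatter_nd_update writes new_values to those coordinates and returns a new tensor.
new_x = tf.tensor_scatter_nd_update(x, tf.where(mask), new_values)
print(new_x)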
Output:
tf.Tensor(
[[0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]], shape=(5, 3), dtype=int32)
tf.Tensor(
[[1 1 1]
 [2 2 2]
 [3 0 0]
 [4 4 0]
 [5 5 0]], shape=(5, 3), dtype=int32)
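If you need this in more than one place, the same calls can be wrapped in a small helper. This is only a sketch, not a TensorFlow API: masked_assign and masked_fill_from_zeros are hypothetical names. masked_assign adds an explicit check that there is one value per True entry. masked_fill_from_zeros covers the all-zeros case from the question with tf.scatter_nd, which creates the zero tensor and fills it in one call. Its shape argument must have the same dtype as the int64 indices returned by tf.where.

```python
import tensorflow as tf


def masked_assign(tensor, mask, values):
    """Hypothetical helper: TF analogue of torch's `tensor[mask] = values`.

    Returns a new tensor in which the positions where `mask` is True are
    filled with `values`, taken in row-major order.
    """
    mask = tf.convert_to_tensor(mask, dtype=tf.bool)
    values = tf.reshape(tf.convert_to_tensor(values, dtype=tensor.dtype), [-1])
    indices = tf.where(mask)  # int64, shape [num_true, rank], row-major order
    # Raise an error if there is not exactly one value per True entry.
    tf.debugging.assert_equal(
        tf.shape(indices)[0], tf.size(values),
        message="values must have one element per True entry in mask")
    return tf.tensor_scatter_nd_update(tensor, indices, values)


def masked_fill_from_zeros(mask, values):
    """Hypothetical helper for the all-zeros case: tf.scatter_nd creates a
    zero tensor of the given shape and writes `values` into it in one call."""
    mask = tf.convert_to_tensor(mask, dtype=tf.bool)
    indices = tf.where(mask)
    # scatter_nd requires `shape` to have the same dtype as `indices` (int64 here).
    shape = tf.shape(mask, out_type=tf.int64)
    return tf.scatter_nd(indices, tf.reshape(values, [-1]), shape)


mask = tf.constant([[True, True, True],
                    [True, True, True],
                    [True, False, False],
                    [True, True, False],
                    [True, True, False]])
new_values = tf.constant([1, 1, 1, 2, 2, 2, 3, 4, 4, 5, 5], dtype=tf.int32)

result = masked_assign(tf.zeros_like(mask, dtype=tf.int32), mask, new_values)
result_zeros = masked_fill_from_zeros(mask, new_values)
print(result)
print(result_zeros)
```

Both helpers return new tensors. TensorFlow tensors are immutable, so unlike the torch version nothing is modified in place; you rebind the name to the result instead.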