Tensorflow: Replace indices in 3D tensor using 2D tensor

Time:09-16

  • I have a 3D tensor X, of shape [7, 240, 768].
  • I have another tensor mask_idx of shape [7, 240] containing 0/False and 1/True. Where it is 0/False, X[i][j] should stay unchanged; where it is 1/True, I want X[i][j] = tf.zeros([768]).

I have tried tf.where(mask_idx, tf.zeros([7, 240, 768]), X), but I get this error:

*** tensorflow.python.framework.errors_impl.InvalidArgumentError: condition [7,240], then [7,240,768], and else [7,240,768] must be broadcastable [Op:SelectV2]

Can anyone suggest the correct approach?

CodePudding user response:

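TF checks whether shapes are broadcastable by aligning them from right to left, so the mask's last dimension (240) gets compared with the last dimension of X (768) and the check fails. One simple fix is to add a trailing dimension to the mask, i.e., make its shape (7, 240, 1), so that the size-1 axis broadcasts across all 768 values. Note that tf.where needs a boolean condition, so a 0/1 integer mask has to be cast to tf.bool first.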

tf.where(mask_idx[...,None], 0, X)
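
For reference, here is a minimal self-contained sketch of the same idea on small stand-in shapes ([2, 3, 4] instead of [7, 240, 768]). The sample values, the integer 0/1 mask, and the names mask and result are assumptions made for illustration; the question does not say which dtype mask_idx actually has.

```python
import tensorflow as tf

# Small stand-in for the [7, 240, 768] tensor from the question.
# Values 1..24 are used so that no original entry is already zero.
X = tf.reshape(tf.range(1, 25, dtype=tf.float32), [2, 3, 4])

# Assumed: the mask arrives as 0/1 integers, so cast it to bool for tf.where.
mask_idx = tf.constant([[0, 1, 0],
                        [1, 0, 0]])
mask = tf.cast(mask_idx, tf.bool)

# mask[..., None] has shape [2, 3, 1]; the trailing size-1 axis broadcasts
# across the last axis of X, so each selected X[i][j] becomes all zeros.
result = tf.where(mask[..., None], tf.zeros_like(X), X)

print(result.shape)
print(result[0, 1].numpy())  # row selected by the mask
print(result[0, 0].numpy())  # row left unchanged
```

tf.zeros_like(X) is used here instead of a bare 0 only to make the dtype explicit; the scalar form in the answer relies on the same broadcasting.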