- I have a 3D tensor X of shape [7, 240, 768].
- I have another tensor mask_idx of shape [7, 240] which contains 0/False and 1/True, where 0/False means I don't want to update the value in X[i][j], and 1/True means I want to set X[i][j] = tf.zeros([768]).
I have tried using tf.where(mask_idx, tf.zeros([7, 240, 768]), X), but I am getting this error:
*** tensorflow.python.framework.errors_impl.InvalidArgumentError: condition [7,240], then [7,240,768], and else [7,240,768] must be broadcastable [Op:SelectV2]
Can anyone suggest the correct approach?
CodePudding user response:
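TF checks whether dimensions are broadcastable from right to left, so the mask's last dimension (240) gets compared with the last dimension of X (768) and the check fails. One simple way around this is to expand your mask tensor in the last dimension, i.e., make its shape (7, 240, 1), which does broadcast against (7, 240, 768):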
tf.where(mask_idx[...,None], 0, X)
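
For anyone who wants to try this end to end, here is a minimal sketch of the same idea. The shapes come from the question; the random contents of X and mask_idx are placeholder data, and the names mask_bool, expanded, result and result_mul are made up for illustration. It assumes TensorFlow 2.x in eager mode. The cast is only there in case the mask is stored as 0/1 integers rather than booleans.

```python
import tensorflow as tf

# Placeholder data with the shapes from the question (contents are assumed).
X = tf.random.normal([7, 240, 768])
mask_idx = tf.random.uniform([7, 240]) > 0.5  # True means "zero out X[i][j]"

# If the mask holds 0/1 integers instead of booleans, cast it first.
# For a mask that is already boolean this is a no-op.
mask_bool = tf.cast(mask_idx, tf.bool)

# Add a trailing axis: [7, 240] -> [7, 240, 1].
# Same effect as mask_idx[..., None] in the answer above.
expanded = tf.expand_dims(mask_bool, axis=-1)

# The [7, 240, 1] condition now broadcasts against [7, 240, 768].
result = tf.where(expanded, tf.zeros_like(X), X)

# Possible alternative: multiply by the inverted mask cast to X's dtype.
# Unlike tf.where, this will not clear NaN or inf values in the masked rows.
result_mul = X * tf.cast(tf.logical_not(expanded), X.dtype)

print(result.shape)  # shape is unchanged: (7, 240, 768)
```

Using tf.zeros_like(X) instead of a literal 0 is just a way to keep the dtype of the replacement explicit; the one-liner in the answer does the same job.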