Using only TensorFlow, how can I select the rows of a tensor that satisfy a condition?
Example tensor x:
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[0, 1, 2],
       [1, 1, 2],
       [0, 1, 4]], dtype=int32)>
I'd like to create a new tensor that contains only the rows of x whose first element equals 0.
Answer:
import tensorflow as tf
x = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
# Filter a NumPy copy of the rows (eager mode only), then convert back to a tensor
x = tf.constant([i for i in x.numpy() if i[0] == 0])
Or, using only TensorFlow ops:
a = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
# Boolean mask: True where the first element of a row equals 0
mask = tf.equal(a[:, 0], 0)
# Keep only the rows whose mask entry is True
a = tf.boolean_mask(a, mask)
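An equivalent approach selects rows by index instead of by mask: `tf.where` called with a single argument returns the coordinates of the `True` entries, and `tf.gather` then picks those rows along axis 0. This is a minimal sketch; the helper name `select_rows_where_first_is_zero` is hypothetical and chosen only for illustration.

```python
import tensorflow as tf

def select_rows_where_first_is_zero(t: tf.Tensor) -> tf.Tensor:
    # Hypothetical helper: returns the rows of a 2-D tensor whose first element is 0.
    # tf.where with one argument gives indices of True entries, shape (num_true, 1).
    row_ids = tf.where(tf.equal(t[:, 0], 0))
    # Flatten the indices to shape (num_true,) and gather those rows.
    return tf.gather(t, tf.squeeze(row_ids, axis=1), axis=0)

x = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
print(select_rows_where_first_is_zero(x))
```

If no row matches, the result is an empty tensor with shape (0, number_of_columns) rather than an error. Because the helper uses only TensorFlow ops, it should also work when wrapped in `tf.function`.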