How to select rows of tensor based on condition (tensorflow)

Time:07-04

Using only TensorFlow, how can I select the rows of a tensor that satisfy a condition?

Example tensor x:

<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[0, 1, 2],
       [1, 1, 2],
       [0, 1, 4]], dtype=int32)>

I'd like to create a new tensor that contains only the rows of x whose first element equals 0.

Answer:

import tensorflow as tf

x = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
# Filter the rows in Python via NumPy (eager mode only), then rebuild a tensor
x = tf.constant([row for row in x.numpy() if row[0] == 0])
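The list comprehension above loops over the rows in Python. A minimal sketch of a vectorized variant that still goes through NumPy uses NumPy boolean indexing. It assumes eager execution, because `.numpy()` is only available on eager tensors; the names `x_np` and `result` are just for illustration:

```python
import tensorflow as tf

x = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])

# Convert to a NumPy array once (eager mode only).
x_np = x.numpy()

# Boolean indexing keeps the rows whose first element equals 0,
# then the result is wrapped back into a tensor.
result = tf.constant(x_np[x_np[:, 0] == 0])
print(result.numpy())
```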

Alternatively, using only TensorFlow ops:

a = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
# Boolean vector: True for rows whose first element is 0
mask = tf.equal(a[:, 0], 0)
# Keep only the rows where mask is True
a = tf.boolean_mask(a, mask)
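Because this version uses only TensorFlow ops, it also works inside a `tf.function`, where `.numpy()` is not available. A minimal sketch of an equivalent formulation uses the single-argument form of `tf.where` (which returns the indices of `True` entries) together with `tf.gather`. This is a swapped-in alternative to `tf.boolean_mask`, not the answer's method, and the function name `rows_where_first_is_zero` is hypothetical:

```python
import tensorflow as tf

@tf.function
def rows_where_first_is_zero(t):
    # Boolean vector: True where column 0 equals 0.
    mask = tf.equal(t[:, 0], 0)
    # Single-argument tf.where returns int64 indices of the True entries,
    # with shape [num_true, 1]; squeeze them to a 1-D index vector.
    idx = tf.squeeze(tf.where(mask), axis=1)
    # Gather the selected rows along axis 0.
    return tf.gather(t, idx)

a = tf.constant([[0, 1, 2], [1, 1, 2], [0, 1, 4]])
result = rows_where_first_is_zero(a)
print(result.numpy())
```

Both `tf.boolean_mask` and the `tf.where` plus `tf.gather` pair produce the same rows; the index-based form is handy when you also need the positions of the selected rows.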
