How to index values in Tensorflow based on a Tensor?

Time:01-29

I have the following three tensors in TensorFlow v1:

mod_labels = [0,1,1,0,0]
feats = [[1,2,1], [3,2,6], [1,1,1], [9,8,4], [5,4,8]]
labels = [1,53,12,89,54]

I want to split feats and labels into four new tensors according to the values in mod_labels:

# Make new tensors for mod_labels=0
mod0_feats = [[1,2,1], [9,8,4], [5,4,8]]
mod0_labels = [1,89,54]

# Similarly make tensors for mod_labels=1
mod1_feats = [[3,2,6], [1,1,1]]
mod1_labels = [53,12]

I tried using a for loop to iterate over mod_labels, but TensorFlow does not allow iterating over placeholders.
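A Python loop isn't actually needed here: the split can be expressed with element-wise graph ops that run once the placeholder values are fed. The following is a minimal sketch of the placeholder setup the question describes, assuming placeholder shapes inferred from the example data ([None] for the labels, [None, 3] for the features) and using tf.equal with tf.boolean_mask to do the split. It goes through tf.compat.v1 and an explicit tf.Graph so the same code should run under recent TF 1.x releases and under TF 2.x.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Placeholders with an unknown leading dimension (shapes assumed from the example data).
    mod_labels = tf.compat.v1.placeholder(tf.int32, shape=[None])
    feats = tf.compat.v1.placeholder(tf.int32, shape=[None, 3])
    labels = tf.compat.v1.placeholder(tf.int32, shape=[None])

    # Element-wise masks are computed inside the graph, so there is no need
    # to loop over placeholder values that are unknown at graph-build time.
    is_mod0 = tf.equal(mod_labels, 0)
    is_mod1 = tf.equal(mod_labels, 1)

    # boolean_mask keeps the rows (or entries) where the mask is True, in order.
    mod0_feats = tf.boolean_mask(feats, is_mod0)
    mod0_labels = tf.boolean_mask(labels, is_mod0)
    mod1_feats = tf.boolean_mask(feats, is_mod1)
    mod1_labels = tf.boolean_mask(labels, is_mod1)

# The actual values are supplied only at run time.
feed = {
    mod_labels: [0, 1, 1, 0, 0],
    feats: [[1, 2, 1], [3, 2, 6], [1, 1, 1], [9, 8, 4], [5, 4, 8]],
    labels: [1, 53, 12, 89, 54],
}

with tf.compat.v1.Session(graph=graph) as sess:
    out0_feats, out0_labels, out1_feats, out1_labels = sess.run(
        [mod0_feats, mod0_labels, mod1_feats, mod1_labels], feed_dict=feed)
```

sess.run returns NumPy arrays, so out0_feats here should hold the rows [[1,2,1], [9,8,4], [5,4,8]] from the question.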

Answer:

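Assuming those are tensors, tf.where can find the indices where mod_labels equals each value, and tf.gather_nd can then collect the matching rows: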

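import tensorflow as tf

mod_labels = tf.convert_to_tensor([0,1,1,0,0])
feats = tf.convert_to_tensor([[1,2,1], [3,2,6], [1,1,1], [9,8,4], [5,4,8]])
labels = tf.convert_to_tensor([1,53,12,89,54])
# tf.where with a single boolean argument returns the indices of the True entries, shape [N, 1].
# Use tf.equal rather than ==: in TF 1.x, == on tensors compares object identity, not values.
zeros = tf.where(tf.equal(mod_labels, 0))
ones = tf.where(tf.equal(mod_labels, 1))
# gather_nd picks the rows of feats (or entries of labels) at those indices.
mod0_feats = tf.gather_nd(feats, zeros)
mod0_labels = tf.gather_nd(labels, zeros)
mod1_feats = tf.gather_nd(feats, ones)
mod1_labels = tf.gather_nd(labels, ones)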
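When every row belongs to exactly one group, tf.dynamic_partition can produce all the groups in a single call instead of one tf.where per value. A minimal sketch, assuming TF 2.x eager execution so the results can be inspected directly (under TF 1.x, evaluate the outputs in a Session instead); it relies on mod_labels being int32 values in the range [0, num_partitions).

```python
import tensorflow as tf

# int32 partition indices, one per row of the data tensors.
mod_labels = tf.constant([0, 1, 1, 0, 0], dtype=tf.int32)
feats = tf.constant([[1, 2, 1], [3, 2, 6], [1, 1, 1], [9, 8, 4], [5, 4, 8]])
labels = tf.constant([1, 53, 12, 89, 54])

# dynamic_partition sends row i of the data to output list entry mod_labels[i],
# keeping the original relative order within each partition.
mod0_feats, mod1_feats = tf.dynamic_partition(feats, mod_labels, num_partitions=2)
mod0_labels, mod1_labels = tf.dynamic_partition(labels, mod_labels, num_partitions=2)
```

With more distinct label values, raise num_partitions accordingly; the tf.where/tf.gather_nd approach (or tf.boolean_mask) is still convenient when only one group is needed.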