How to slice according to batch in the tensorflow array?

Time:01-03

I have an array output and an ID array subject_ids.

output = [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]

subject_ids = [[0, 1], [1, 2], [0, 2]]

Each pair in subject_ids gives a start position and an end position (both inclusive), and I want to get the vectors between those positions.

For example, in this case I should get [[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 8, 9]] and [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
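The intended selection can be sketched in plain Python to pin down the semantics: each [start, end] pair keeps rows start through end, with the end included, and the singleton inner list is unwrapped. This is only an illustration; select_spans is a hypothetical helper name, not part of NumPy or TensorFlow.

```python
# Plain-Python illustration of the intended semantics:
# each [start, end] pair selects rows start..end, with end INCLUSIVE.
output = [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]
subject_ids = [[0, 1], [1, 2], [0, 2]]


def select_spans(rows, spans):
    """Hypothetical helper: for each (start, end) pair, return the inner
    vectors of rows[start..end], end included."""
    result = []
    for start, end in spans:
        # end + 1 because Python slices exclude their stop index;
        # row[0] unwraps the singleton level, [[1, 2, 3]] -> [1, 2, 3]
        result.append([row[0] for row in rows[start:end + 1]])
    return result


print(select_spans(output, subject_ids))
# [[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
```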

What should I do? I tried tf.slice and tf.gather, but neither seemed to work.
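For reference, tf.slice can extract one span at a time; the catch is that its size argument is a length, not an end index, so for an inclusive end it must be end - start + 1 (and -1 means "everything remaining" along an axis). A minimal sketch, assuming TensorFlow 2.x with eager execution; the Python loop and the slices list are illustrative choices, not taken from the question.

```python
import tensorflow as tf

output = tf.constant([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]])  # shape (3, 1, 3)
subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])

slices = []
for pair in subject_ids:  # iterating over a tensor works in eager mode
    start, end = pair[0], pair[1]
    # begin = [start, 0, 0]; size = [end - start + 1, -1, -1]
    # (size counts elements, and -1 takes everything left on that axis)
    piece = tf.slice(output, [start, 0, 0], [end - start + 1, -1, -1])
    slices.append(tf.squeeze(piece, axis=1))  # (n, 1, 3) -> (n, 3)

for s in slices:
    print(s.numpy().tolist())
```

Because each span has a different length, the results are kept in a Python list of dense tensors rather than one tensor.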

Answer:

How about just indexing with NumPy? (This assumes output is a NumPy array, e.g. output = np.array(output).)

>>> [output[np.arange(x, y + 1)] for x, y in subject_ids]
[array([[[1, 2, 3]],
        [[4, 5, 6]]]),
        
 array([[[4, 5, 6]],
        [[7, 8, 9]]]),
        
 array([[[1, 2, 3]],
        [[4, 5, 6]],
        [[7, 8, 9]]])]
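The one-liner above assumes output is already a NumPy array. A self-contained variant follows; it swaps the np.arange fancy indexing for a basic slice x:y + 1 (same rows, but basic slices return views of output rather than copies) and indexes 0 on the middle axis to drop the singleton dimension, so each span comes out with shape (n, 3).

```python
import numpy as np

output = np.array([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]])  # shape (3, 1, 3)
subject_ids = [[0, 1], [1, 2], [0, 2]]

# Basic slicing with an inclusive end (y + 1); the 0 drops the singleton axis.
# Note: these are views, so writing into them also modifies output.
spans = [output[x:y + 1, 0, :] for x, y in subject_ids]

for span in spans:
    print(span.tolist())
```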

Answer:

If you want to use only TensorFlow, try combining tf.gather with tf.range and tf.ragged.stack:

import tensorflow as tf

output = tf.constant([
                      [[1, 2, 3]], 
                      [[4, 5, 6]], 
                      [[7, 8, 9]]
                      ])

subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])

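# For each [start, end] pair, gather rows start..end (end inclusive, hence + 1)
# and stack the variable-length results into a RaggedTensor.
ragged_output = tf.ragged.stack([tf.gather(output, tf.range(subject_ids[i, 0], subject_ids[i, 1] + 1)) for i in tf.range(0, tf.shape(subject_ids)[0])], axis=0)
# Remove the size-1 axis so each row holds 3-element vectors.
ragged_output = tf.squeeze(ragged_output, axis=2)
print(ragged_output)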
[[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
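The Python loop can also be avoided entirely. A minimal sketch, assuming TensorFlow 2.x: tf.ragged.range builds one ragged row of indices per [start, end] pair (its limit is exclusive, hence + 1), and tf.gather accepts those ragged indices, returning a RaggedTensor of shape (3, None, 3). This is an alternative technique to the tf.ragged.stack answer above, not a correction of it.

```python
import tensorflow as tf

output = tf.constant([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]])
subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])

# Drop the singleton axis first: shape (3, 1, 3) -> (3, 3).
rows = tf.squeeze(output, axis=1)

# One ragged row of indices per pair: [0, 1], [1, 2], [0, 1, 2].
# tf.ragged.range excludes its limit, so add 1 to make the end inclusive.
indices = tf.ragged.range(subject_ids[:, 0], subject_ids[:, 1] + 1)

# Gathering with ragged indices yields a RaggedTensor of shape (3, None, 3).
result = tf.gather(rows, indices)
print(result.to_list())
```

Since everything is expressed as tensor ops, this version should also work inside a tf.function without relying on Python-level iteration.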