A little example to demonstrate what I need
I have a question about gathering in TensorFlow. Let's say I have a tensor of values (that I care about for some reason):
test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
which gives me this output:
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 2.],
[4., 5., 0.]], dtype=float32)>
and I also have a tensor of column indices that I want to pick out from every row:
test_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)
I want to gather so that from the first row (row 0) I pick out the items in columns 0, 1, 0, 0, 1, and from the second row (row 1) the items in columns 0, 1, 1, 1, 0.
So the output for this example should be:
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
[4., 5., 5., 5., 4.]], dtype=float32)>
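For reference, this per-row lookup is the same operation NumPy exposes as np.take_along_axis along axis 1 (the question does not use it; it is shown only to pin down the intended semantics). The sketch below assumes the fixed values from the example output in place of the random test1:

```python
import numpy as np

# Fixed values taken from the example output (assumed in place of the random test1).
values = np.array([[1., 1., 2.],
                   [4., 5., 0.]], dtype=np.float32)
col_idx = np.array([[0, 1, 0, 0, 1],
                    [0, 1, 1, 1, 0]], dtype=np.int64)

# For every row i and position j, pick values[i, col_idx[i, j]].
expected = np.take_along_axis(values, col_idx, axis=1)
print(expected)
# [[1. 1. 1. 1. 1.]
#  [4. 5. 5. 5. 4.]]
```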
My attempt
So I figured out a way to do this in general: I wrote the following function, gather_matrix_indices(), which takes a tensor of values and a tensor of indices and does exactly what I specified above.
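import tensorflow as tf

def gather_matrix_indices(input_arr, index_arr):
    row, _ = input_arr.shape
    li = []
    for i in range(row):
        # Gather the requested columns from row i, then add a leading axis so the rows can be concatenated.
        li.append(tf.expand_dims(tf.gather(params=input_arr[i], indices=index_arr[i]), axis=0))
    return tf.concat(li, axis=0)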
My Question
Is there a way to do this using only built-in TensorFlow or NumPy methods? The only solution I could come up with is writing my own function that iterates over every row and gathers the requested columns from that row. I have not run into runtime issues yet, but I would much rather use built-in TensorFlow or NumPy methods where possible. I've tried tf.gather before, but I don't know whether this particular case is possible with some combination of tf.gather and tf.gather_nd. If anyone has a suggestion, I would greatly appreciate it.
Answer
You can use gather_nd() for this. It can look a bit tricky to get working, so let me explain it in terms of shapes.
We have test1 -> [2, 3] and test_ind_col_ind -> [2, 5]. test_ind_col_ind holds only column indices, but gather_nd() also needs row indices. To use gather_nd() with a [2, 3] tensor, we need to create a test_ind -> [2, 5, 2] sized tensor. The innermost dimension of this new test_ind corresponds to an individual index into test1; here the innermost dimension is 2, in the format (<row index>, <col index>). In other words, looking at the shape of test_ind:
[2, 5, 2]: the leading (2, 5) is the size of the final tensor, and the trailing (2,) is the full index to a scalar in your input tensor.
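To make the [2, 5, 2] index tensor concrete, here is a minimal sketch that builds it for the example by broadcasting the row numbers against the column indices and stacking them into (row, col) pairs. It uses tf.broadcast_to and tf.stack instead of tf.repeat and tf.concat; the names col_ind, row_ind and full_ind are illustrative only.

```python
import tensorflow as tf

# Column indices from the question, shape [2, 5].
col_ind = tf.constant([[0, 1, 0, 0, 1],
                       [0, 1, 1, 1, 0]], dtype=tf.int64)

# Row number for every position, shape [2, 5]: [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]].
rows = tf.range(col_ind.shape[0], dtype=tf.int64)
row_ind = tf.broadcast_to(rows[:, tf.newaxis], tf.shape(col_ind))

# Pair them up as (<row index>, <col index>) along a new last axis -> shape [2, 5, 2].
full_ind = tf.stack([row_ind, col_ind], axis=-1)
print(full_ind.shape)             # (2, 5, 2)
print(full_ind[1, 2].numpy())     # [1 1]  i.e. row 1, column 1
```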
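import tensorflow as tf

test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
print(test1)

# Column indices, shape [2, 5] -> [2, 5, 1] after adding a trailing axis.
test_ind_col_ind = tf.constant([[0,1,0,0,1],
                                [0,1,1,1,0]], dtype=tf.int64)[:, :, tf.newaxis]
# Row indices: row number i repeated 5 times for row i, shape [2, 5, 1].
test_ind_row_ind = tf.repeat(tf.range(2, dtype=tf.int64)[:, tf.newaxis, tf.newaxis], 5, axis=1)
# Concatenate into (<row index>, <col index>) pairs, shape [2, 5, 2].
test_ind = tf.concat([test_ind_row_ind, test_ind_col_ind], axis=-1)
res = tf.gather_nd(indices=test_ind, params=test1)
print(res)  # shape [2, 5]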
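As an alternative that skips building row indices by hand (not part of the answer above), tf.gather accepts a batch_dims argument. With batch_dims=1 the first axis is treated as a batch axis, so row i of the indices is looked up in row i of params, which is the per-row gather from the question. A sketch, assuming the fixed values from the question's example output instead of the random test1, and checking it against the gather_nd approach:

```python
import tensorflow as tf

# Fixed values from the question's example output (assumed instead of the random test1).
test1 = tf.constant([[1., 1., 2.],
                     [4., 5., 0.]])
test_ind = tf.constant([[0, 1, 0, 0, 1],
                        [0, 1, 1, 1, 0]], dtype=tf.int64)

# batch_dims=1: axis defaults to 1 here, so row i of test_ind indexes row i of test1.
res_gather = tf.gather(test1, test_ind, batch_dims=1)

# Same lookup via gather_nd with explicit (row, col) pairs.
row_ind = tf.broadcast_to(tf.range(2, dtype=tf.int64)[:, tf.newaxis], tf.shape(test_ind))
res_nd = tf.gather_nd(test1, tf.stack([row_ind, test_ind], axis=-1))

print(res_gather.numpy())
# [[1. 1. 1. 1. 1.]
#  [4. 5. 5. 5. 4.]]
```

The batch_dims form keeps the indices in the same [2, 5] shape as the question's test_ind, whereas gather_nd needs the extra trailing (row, col) axis.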