Basically, I have a list of indices generated by tf.random.uniform and a tensor called tensor_big. Now I need to create a new tensor, tensor_small, that contains every element of tensor_big whose second coordinate (the column index) is in the indices list.
Example:
Indices = [1, ......]
Then the new tensor must contain the weights at positions [0,1], [1,1], [2,1], and so on, and likewise for every other index in the list.
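In other words, tensor_small[i][j] should equal tensor_big[i][indices[j]]: every row is kept, and only the listed columns survive. Below is a minimal pure-Python sketch of that selection on a small, made-up 3x4 nested list. The helper name select_columns is hypothetical and not a TensorFlow function.

```python
def select_columns(matrix, indices):
    """Return a new matrix whose j-th column is column indices[j] of matrix.

    select_columns is a hypothetical helper used only to illustrate the selection.
    """
    # Keep every row; within each row, pick the entries whose column index is listed.
    return [[row[k] for k in indices] for row in matrix]


# Made-up 3x4 example: the entry at row r, column c is 10*r + c.
big = [[10 * r + c for c in range(4)] for r in range(3)]
small = select_columns(big, [1, 3])
print(small)  # [[1, 3], [11, 13], [21, 23]]
```

Column 1 of the result's source holds exactly the entries at [0,1], [1,1], [2,1] (values 1, 11, 21), matching the example above.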
import tensorflow as tf

if __name__ == '__main__':
    tensor_big = tf.random.uniform(
        (3136, 512), minval=0, maxval=None, dtype=tf.dtypes.float32, seed=None, name=None
    )
    indices = tf.random.uniform(shape=[410, ], minval=0, maxval=512, dtype=tf.dtypes.int32, seed=None, name=None)
    # Prints the entry in column 1 of every row.
    for weight in tensor_big:
        print(weight[1])
    # WHERE_SECOND_COORDINATE_INSIDE_INDICES is a placeholder for the part I can't figure out.
    tensor_small = tf.reshape(tf.gather(tensor_big, WHERE_SECOND_COORDINATE_INSIDE_INDICES), (3136, 410))
    print(tensor_small)
Answer:
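You can use tf.gather with the argument axis=1 so that the indices select columns (the default, axis=0, would select rows). The result already has shape (3136, 410), so the tf.reshape call is not needed: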
tensor_small = tf.gather(tensor_big, indices, axis=1)
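A minimal end-to-end sketch, assuming TensorFlow 2 with eager execution and the shapes from the question. It checks one column of the result against the original tensor. As an optional extra, it shows one way to draw the 410 column indices without repeats (tf.random.shuffle over tf.range), since tf.random.uniform can return the same index more than once; whether repeats matter depends on your use case.

```python
import tensorflow as tf

# Same shapes as in the question: 3136 rows, 512 columns.
tensor_big = tf.random.uniform((3136, 512), dtype=tf.float32)

# 410 column indices in [0, 512); tf.random.uniform may draw the same index twice.
indices = tf.random.uniform(shape=[410], minval=0, maxval=512, dtype=tf.int32)

# axis=1 gathers along the column dimension, so every row is kept.
tensor_small = tf.gather(tensor_big, indices, axis=1)
print(tensor_small.shape)  # (3136, 410)

# Column j of tensor_small is column indices[j] of tensor_big.
j = 0
tf.debugging.assert_equal(tensor_small[:, j], tensor_big[:, indices[j]])

# Optional: if each column should be picked at most once, shuffle the column
# range and keep the first 410 entries (sampling without replacement).
unique_indices = tf.random.shuffle(tf.range(512))[:410]
tensor_small_unique = tf.gather(tensor_big, unique_indices, axis=1)
```

Because tf.gather keeps duplicates, a repeated index simply produces a repeated column in tensor_small.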