Repeat specific elements of a tensor using Keras

Time:09-07

I am using a 4-D tensor of shape=(N, 2, 127, 52)

I used:

tf.keras.backend.repeat_elements(tensor, 2, axis=3)

This doubles the size of the last axis from 52 to 104 by repeating each value:

shape=(N, 2, 127, 104)
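
For illustration, here is a minimal sketch of that behavior on a small made-up tensor. It uses tf.repeat, which repeats values along an axis in the same way; the name `small` and its values are only placeholders for the real input.

```python
import tensorflow as tf

# Small made-up tensor standing in for the real (N, 2, 127, 52) input.
small = tf.constant([[1, 2, 3],
                     [4, 5, 6]])

# Each value along the last axis is duplicated in place.
repeated = tf.repeat(small, repeats=2, axis=-1)

print(repeated.numpy().tolist())
# [[1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6]]
```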

Now I want to do the same, but only with the last 10 elements along axis 3, so that the result has:

shape=(N, 2, 127, 114)

I would also like to add an extra "column" by inserting a zero vector in the middle of the last axis, resulting in:

shape=(N, 2, 127, 115)

How can I do this?

CodePudding user response:

I think using tf.concat would be a simple way:

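import tensorflow as tf

N = 2
tensor = tf.random.normal((N, 2, 127, 52))

# Repeat every element along the last axis: (N, 2, 127, 104)
tensor = tf.repeat(tensor, 2, axis=3)

# Append a copy of the last 10 elements: (N, 2, 127, 114)
tensor = tf.concat([tensor, tensor[..., -10:]], axis=-1)

# Insert a zero column in the middle of the last axis: (N, 2, 127, 115)
middle = tf.shape(tensor)[-1] // 2
# zeros_like on a one-column slice avoids hard-coding N and the other dimensions
zeros = tf.zeros_like(tensor[..., :1])
tensor = tf.concat([tensor[..., :middle], zeros, tensor[..., middle:]], axis=-1)

print(tensor.shape)  # (2, 2, 127, 115)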
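
Note that the tf.concat step appends a copy of the trailing block rather than duplicating each value in place. If the intent is instead to repeat each of the last 10 values next to itself, the way repeat_elements does, one possible sketch is to pass a list of per-position counts to tf.repeat. This assumes the input already has 104 entries on the last axis; the names `size`, `repeats` and `stretched` are illustrative only.

```python
import tensorflow as tf

N = 2
# Assumed input: the tensor after the first repeat, i.e. 104 entries on the last axis.
tensor = tf.random.normal((N, 2, 127, 104))

# Per-position repeat counts along the last axis:
# 1 for the first 94 positions, 2 for the last 10.
size = tensor.shape[-1]
repeats = [1] * (size - 10) + [2] * 10

# The last axis grows to 94 + 20 = 114 entries.
stretched = tf.repeat(tensor, repeats, axis=-1)

print(stretched.shape)  # (2, 2, 127, 114)
```

Both approaches give a last axis of size 114; they differ only in the order of the values at the end.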