I am building a model that applies a random shuffle to the data along the first non-batch axis, applies a series of Conv1D layers, then applies the inverse of the shuffle. Unfortunately, the tf.gather layer messes up the batch dimension (None), and I'm not sure why.
Below is an example of what happens.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
dim = 90
input_img = keras.Input(shape=(dim, 4))
# Get random shuffle order
order = layers.Lambda(lambda x: tf.random.shuffle(tf.range(x)))(dim)
# Apply shuffle
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))(input_img, order)
model = keras.models.Model(
    inputs=[input_img],
    outputs=tensor,
)
Here the summary is as follows:
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 90, 4)]           0
_________________________________________________________________
lambda_51 (Lambda)           (90, 90, 4)               0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
Whereas I want the output shape of lambda_51 to be (None, 90, 4).
CodePudding user response:
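Try wrapping input_img and order in a list when you pass them to the tensor layer. With two separate positional arguments, only input_img reaches the lambda as x, so x[0] and x[1] are slices along the batch axis, which is why the output shape comes out as (90, 90, 4).
With a list, the tensor layer becomes: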
tensor = layers.Lambda(lambda x: tf.gather(x[0], tf.cast(x[1], tf.int32), axis=1,))([input_img, order])
and the summary shows the expected output shape:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 90, 4)]           0
_________________________________________________________________
lambda_3 (Lambda)            (None, 90, 4)             0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
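The shape arithmetic can be checked outside Keras with plain eager tensors. The following is only a rough sketch: the batch size of 2 and the names batch, data, wrong, shuffled, inverse and restored are assumptions made for illustration. Using tf.argsort to build the inverse of the shuffle is one possible approach, not something taken from the question.

```python
import tensorflow as tf

dim = 90

# Stand-in for a batch of two inputs (batch size 2 is an assumption).
# Zeros keep the values valid as indices for the mistaken call below.
batch = tf.zeros((2, dim, 4))

# Mistaken call: x is only the batch tensor, so x[0] and x[1] are the
# first two samples, each of shape (90, 4). Gathering along axis 1 of
# x[0] with (90, 4) indices yields shape (90,) + (90, 4) = (90, 90, 4).
wrong = tf.gather(batch[0], tf.cast(batch[1], tf.int32), axis=1)
print(wrong.shape)  # (90, 90, 4)

# Intended call: gather the whole batch along axis 1 with a 1-D permutation.
order = tf.random.shuffle(tf.range(dim))
data = tf.random.normal((2, dim, 4))
shuffled = tf.gather(data, order, axis=1)
print(shuffled.shape)  # (2, 90, 4)

# Inverse of the shuffle: the argsort of a permutation undoes it.
inverse = tf.argsort(order)
restored = tf.gather(shuffled, inverse, axis=1)
print(bool(tf.reduce_all(tf.equal(restored, data))))  # True
```

Inside the model, the same inverse gather could presumably go in a second Lambda layer that also receives its inputs as a list.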