Apply tf.keras model to tensor of variable shape

Time:11-24

I have a tf.keras model that takes as input a 1-D tensor of shape (batch_size,) and outputs another tensor of the same shape. The result at index i does not depend on any of the inputs at index j != i.

I would like to apply this model to tensors of any shape (dim1, dim2, ..., dimn). In theory this should be possible, but in practice TensorFlow refuses to process any input with more than one dimension. What would be the most elegant workaround? I've looked at tf.map_fn, but it might get complicated when nested recursively. Is there a simpler method I'm overlooking?
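For concreteness, here is a minimal, hypothetical stand-in for the model described above (`ElementwiseModel` is not from the question). Its computation is purely elementwise, and it raises an error for inputs that are not 1-D. That rank check is an assumption meant only to mimic the restriction; a real Keras model may fail differently.

```python
import tensorflow as tf


class ElementwiseModel(tf.keras.Model):
    """Hypothetical stand-in for the model in the question.

    Maps a 1-D tensor of shape (batch_size,) to a tensor of the same shape,
    where output[i] depends only on input[i]. The explicit rank check is an
    assumption that mimics the "1-D input only" restriction.
    """

    def call(self, x):
        if len(x.shape) != 1:
            raise ValueError(f"expected a 1-D input, got shape {x.shape}")
        # Purely elementwise computation: no mixing across indices.
        return 2.0 * x + 1.0


model = ElementwiseModel()
print(model(tf.constant([0.0, 1.0, 2.0])).numpy())  # [1. 3. 5.]

# A 2-D input is rejected by the stand-in's rank check.
try:
    model(tf.zeros((2, 3)))
except ValueError as err:
    print("rejected:", err)
```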


In the end I solved it like this:

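import tensorflow as tf


def apply_model(X: tf.Tensor, my_model: tf.keras.Model) -> tf.Tensor:
    """
    Apply a tf.keras.Model to a tensor of unknown dimensions.

    Args:
        X (tf.Tensor): The tensor containing the input.
        my_model (tf.keras.Model): The model you want to apply.

    Returns:
        tf.Tensor: A tensor of the same shape as X, where all values are
            a prediction by the model.
    """
    if len(X.shape) > 1:
        # Split X along its last axis, process each slice recursively
        # (passing the model along), and stack the results on the same axis.
        # Note: tf.unstack needs the size of that axis to be known statically.
        result = tf.stack(
            [apply_model(x, my_model) for x in tf.unstack(X, axis=-1)],
            axis=-1,
        )
    else:
        # Base case: a 1-D tensor of shape (batch_size,), which the model accepts.
        result = my_model(X)

    return result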
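Because each output depends only on the input at the same index, another option (not the one used in this answer) is to flatten X to 1-D, call the model once, and reshape the result back to the original shape. A minimal sketch; `apply_model_flat` and the `ElementwiseModel` stand-in are hypothetical names:

```python
import tensorflow as tf


class ElementwiseModel(tf.keras.Model):
    """Hypothetical 1-D-only stand-in: output[i] depends on input[i] only."""

    def call(self, x):
        if len(x.shape) != 1:
            raise ValueError(f"expected a 1-D input, got shape {x.shape}")
        return 2.0 * x + 1.0


def apply_model_flat(X: tf.Tensor, my_model: tf.keras.Model) -> tf.Tensor:
    """Apply an elementwise 1-D model to X of any shape by flattening it.

    Only valid because each output depends solely on the input at the same
    index, so how the elements are grouped into a batch does not matter.
    """
    flat = tf.reshape(X, [-1])                 # shape (dim1 * ... * dimn,)
    flat_out = my_model(flat)                  # a single call on one big batch
    return tf.reshape(flat_out, tf.shape(X))   # restore the original shape


X = tf.reshape(tf.range(24, dtype=tf.float32), (2, 3, 4))
Y = apply_model_flat(X, ElementwiseModel())
print(Y.shape)  # (2, 3, 4)
```

Compared with the recursive version, this makes one model call instead of one per 1-D slice, and it does not depend on tf.unstack, which needs the size of the split axis to be known statically. The trade-off is that all dim1 * ... * dimn elements pass through the model as a single batch, which may use more memory.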

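Of course, you can generalize the recursive apply_model above to a model that takes an input of more than one dimension: stop recursing once X has the rank the model expects, and split only along the axes that are not part of the model's input.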
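A sketch of that generalization, assuming the model's own axes (its batch axis followed by its feature axes) are the trailing `model_rank` axes of X. `RowwiseModel` and `apply_model_nd` are hypothetical names. Unlike the answer's version, which splits along the last axis, this one splits along the first axis so that the trailing feature axes stay intact:

```python
import tensorflow as tf


class RowwiseModel(tf.keras.Model):
    """Hypothetical stand-in for a model whose input has shape (batch_size, 3).

    Each row is scaled by fixed per-feature weights, so row i of the output
    depends only on row i of the input.
    """

    def call(self, x):
        if len(x.shape) != 2:
            raise ValueError(f"expected a 2-D input, got shape {x.shape}")
        return x * tf.constant([1.0, 2.0, 3.0])


def apply_model_nd(X: tf.Tensor, my_model: tf.keras.Model, model_rank: int) -> tf.Tensor:
    """Recursively apply a model expecting rank-`model_rank` inputs to a higher-rank X.

    Assumes the model's axes are the trailing `model_rank` axes of X, so the
    extra leading axes are peeled off one at a time from the front.
    """
    if len(X.shape) > model_rank:
        return tf.stack(
            [apply_model_nd(x, my_model, model_rank) for x in tf.unstack(X, axis=0)],
            axis=0,
        )
    # Base case: X now has exactly the rank the model accepts.
    return my_model(X)


X = tf.reshape(tf.range(2 * 4 * 5 * 3, dtype=tf.float32), (2, 4, 5, 3))
Y = apply_model_nd(X, RowwiseModel(), model_rank=2)
print(Y.shape)  # (2, 4, 5, 3)
```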
