I have a tf.keras model that takes as input a tensor of shape (batch_size,) and outputs another tensor of the same shape. The result at index i does not depend on any of the inputs at index j != i.

I would like to apply this model to tensors of any shape (dim1, dim2, ..., dimn). In theory this should be possible, but in practice TensorFlow refuses to process anything with an input shape of more than one dimension. What would be the most elegant workaround? I've looked at tf.map_fn, but this might get complicated when used recursively. Any simpler methods I'm overlooking?
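To make this concrete, here is a toy stand-in for the kind of model I mean (the layers are just placeholders, the real model is more involved):

import tensorflow as tf

# Hypothetical stand-in: maps a (batch_size,) tensor to a (batch_size,) tensor,
# where output i depends only on input i.
inputs = tf.keras.Input(shape=())                              # shape (None,)
hidden = tf.keras.layers.Lambda(lambda t: t[:, None])(inputs)  # (None, 1)
hidden = tf.keras.layers.Dense(8, activation="relu")(hidden)
out = tf.keras.layers.Dense(1)(hidden)                         # (None, 1)
outputs = tf.keras.layers.Lambda(lambda t: t[:, 0])(out)       # (None,)
my_model = tf.keras.Model(inputs, outputs)

print(my_model(tf.zeros((5,))).shape)  # (5,) -- works
# my_model(tf.zeros((2, 3)))           # does not give the per-element result I want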
CodePudding user response:
In the end I solved it like this:
import tensorflow as tf


def apply_model(X: tf.Tensor, my_model: tf.keras.Model) -> tf.Tensor:
    """
    Apply a tf.keras.Model to a tensor of unknown dimensions.

    Args:
        X (tf.Tensor): The tensor containing the input.
        my_model (tf.keras.Model): The model you want to apply.

    Returns:
        tf.Tensor: A tensor of the same shape as X, where all values are
            a prediction by the model.
    """
    if len(X.shape) > 1:
        # Peel off the last axis and recurse until the slices are 1-D.
        result = tf.stack(
            [apply_model(x, my_model) for x in tf.unstack(X, axis=-1)],
            axis=-1,
        )
    else:
        result = my_model(X)
    return result
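For example, with a toy model standing in for the real one (the model below is purely for illustration):

import tensorflow as tf

# Placeholder 1-D model: doubles each element, (batch_size,) -> (batch_size,).
inputs = tf.keras.Input(shape=())
outputs = tf.keras.layers.Lambda(lambda t: 2.0 * t)(inputs)
toy_model = tf.keras.Model(inputs, outputs)

X = tf.random.uniform((2, 3, 4))      # arbitrary-rank input
Y = apply_model(X, toy_model)
print(Y.shape)                        # (2, 3, 4), same shape as X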
Of course, you can generalize this to the case where the model takes inputs with more than one dimension.