Error with tensorflow map_fn. Unable to specify output signature

Time:12-17

I am trying to use TensorFlow's tf.map_fn to map over a ragged tensor, but I am getting an error that I can't fix. Here is some minimal code that demonstrates the error:

import tensorflow as tf

X = tf.ragged.constant([[0,1,2], [0,1]])
def outer_product(x):
  return x[...,None]*x[None,...]
tf.map_fn(outer_product, X)
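The call above fails. Here is a plausible explanation, based on how tf.map_fn handles RaggedTensor inputs. When fn_output_signature is omitted, map_fn assumes fn returns the same kind of value as one row of X: a RaggedTensorSpec with ragged_rank=0, i.e. a plain 1-D tensor per row. The per-row results are then stacked into a RaggedTensor with a single ragged dimension, so every result must share the same trailing shape. outer_product returns [3, 3] and [2, 2] matrices whose trailing dimensions differ. A shape-preserving function such as tf.math.square (chosen here only for illustration) satisfies the default signature and works without one:

```python
import tensorflow as tf

X = tf.ragged.constant([[0, 1, 2], [0, 1]])

# No fn_output_signature: map_fn reuses the per-row spec of X, so each
# result must be a 1-D tensor (its length may still differ between rows).
squared = tf.map_fn(tf.math.square, X)
print(squared.to_list())  # [[0, 1, 4], [0, 1]]
```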

My desired output is:

tf.ragged.constant([
 [[0, 0, 0],
  [0, 1, 2],
  [0, 2, 4]],
 [[0, 0],
  [0, 1]]
])

The error I am getting is:

"InvalidArgumentError: All flat_values must have compatible shapes. Shape at index 0: [3]. Shape at index 1: [2]. If you are using tf.map_fn, then you may need to specify an explicit fn_output_signature with appropriate ragged_rank, and/or convert output tensors to RaggedTensors. [Op:RaggedTensorFromVariant]"

I realize I need to specify fn_output_signature, but despite experimentation I cannot figure out what it should be.
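One signature that should work is sketched below, assuming a recent TensorFlow 2.x. It is an alternative to the tf.ragged.stack approach used in the answer. The signature describes a single output element (one n×n matrix), not the whole result. It uses ragged_rank=1, and fn returns a RaggedTensor (via tf.RaggedTensor.from_tensor) so that it matches that spec. map_fn then stacks the elements into a RaggedTensor with ragged_rank=2, so both the number of rows and the row lengths of each matrix may vary. No extra dimension needs to be removed afterwards.

```python
import tensorflow as tf

X = tf.ragged.constant([[0, 1, 2], [0, 1]])

def outer_product(x):
    t = x[..., None] * x[None, ...]        # dense [n, n] matrix for one row
    return tf.RaggedTensor.from_tensor(t)  # same values as a RaggedTensor, ragged_rank=1

# The spec describes ONE output element: a 2-D RaggedTensor whose second
# dimension is ragged. map_fn adds the outer (batch) dimension itself.
y = tf.map_fn(
    outer_product,
    X,
    fn_output_signature=tf.RaggedTensorSpec(
        shape=[None, None], dtype=X.dtype, ragged_rank=1
    ),
)
print(y.ragged_rank)  # 2
print(y.to_list())    # [[[0, 0, 0], [0, 1, 2], [0, 2, 4]], [[0, 0], [0, 1]]]
```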

EDIT: I cleaned up AloneTogether's excellent answer a little. His answer uses tf.ragged.stack to convert each output of outer_product into a RaggedTensor. tf.map_fn needs this because a plain dense tensor is not compatible with the RaggedTensorSpec passed as fn_output_signature:

import tensorflow as tf

X = tf.ragged.constant([
                        [0,1,2],
                        [0,1]
                       ])
def outer_product(x):
  t = x[...,None] * x[None,...]   # dense [n, n] outer product
  return tf.ragged.stack(t)       # RaggedTensor with shape [1, None, None], ragged_rank=2


y = tf.map_fn(outer_product, X, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None],
                                                                    dtype=tf.type_spec_from_value(X).dtype,
                                                                    ragged_rank=2))

print(y.shape)  # == (2, 1, None, None)
y = tf.squeeze(y, 1)  # drop the size-1 dimension added by tf.ragged.stack
print(y.shape)  # == (2, None, None)
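To see why the spec in the edited code uses shape=[1, None, None] and ragged_rank=2, it helps to look at what tf.ragged.stack returns for a single output of outer_product, outside of map_fn. In the sketch below, the first row of X serves as the example input. Stacking one dense matrix wraps it in a new leading dimension of size 1 and converts it to a RaggedTensor. That ragged value is compatible with the spec, while the raw dense matrix is not.

```python
import tensorflow as tf

x = tf.constant([0, 1, 2])       # first row of X
t = x[..., None] * x[None, ...]  # dense [3, 3] outer product
s = tf.ragged.stack(t)           # RaggedTensor wrapping t in a size-1 leading dimension

spec = tf.RaggedTensorSpec(shape=[1, None, None], dtype=tf.int32, ragged_rank=2)

print(s.ragged_rank)               # 2
print(s.to_list())                 # [[[0, 0, 0], [0, 1, 2], [0, 2, 4]]]
print(spec.is_compatible_with(s))  # True
print(spec.is_compatible_with(t))  # False: a dense Tensor does not match a ragged spec
```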

Answer:

Ragged tensors are sometimes really tricky. Here is one option you can try out:

import tensorflow as tf

X = tf.ragged.constant([
                        [0,1,2], 
                        [0,1]
                       ])
def outer_product(x):
  t = x[...,None] * x[None,...]
  return tf.ragged.stack(t)


y = tf.map_fn(outer_product, X, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None],
                                                                    dtype=tf.type_spec_from_value(X).dtype,
                                                                    ragged_rank=2,
                                                                    row_splits_dtype=tf.type_spec_from_value(X).row_splits_dtype))
tf.print(y)
y = tf.concat([y[0, :], y[1, :]], axis=0) # Remove additional dimension from Ragged Tensor
tf.print(y)

Output of the first tf.print(y):
[
 [
  [
   [0, 0, 0], 
   [0, 1, 2], 
   [0, 2, 4]
  ]
 ], 
 [
  [
   [0, 0], 
   [0, 1]
  ]
 ]
]

And after removing the additional dimension with tf.concat:

[
 [
  [0, 0, 0], 
  [0, 1, 2], 
  [0, 2, 4]
 ], 
 [
  [0, 0], 
  [0, 1]
 ]
]
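The tf.concat line in the answer indexes y[0, :] and y[1, :] explicitly, so it only handles an X with exactly two rows. Below is a sketch of a variant that does not depend on the number of rows. It assumes the same map_fn call as in the answer, with row_splits_dtype left at its int64 default, which matches tf.ragged.constant. RaggedTensor.merge_dims(0, 1) merges the batch axis with the size-1 axis introduced by tf.ragged.stack.

```python
import tensorflow as tf

X = tf.ragged.constant([[0, 1, 2], [0, 1]])

def outer_product(x):
    t = x[..., None] * x[None, ...]
    return tf.ragged.stack(t)  # shape [1, None, None], ragged_rank=2

y = tf.map_fn(
    outer_product,
    X,
    fn_output_signature=tf.RaggedTensorSpec(
        shape=[1, None, None], dtype=X.dtype, ragged_rank=2
    ),
)

# Merge axis 0 (one entry per row of X) with axis 1 (the size-1 axis from
# tf.ragged.stack). Unlike indexing y[0, :] and y[1, :], this works for any
# number of rows.
flat = y.merge_dims(0, 1)
print(flat.to_list())  # [[[0, 0, 0], [0, 1, 2], [0, 2, 4]], [[0, 0], [0, 1]]]
```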