I'm trying to take variable-length tensors and split them into tensors of length 4, discarding any extra elements (if the length is not divisible by four).
I've therefore written the following function:
import tensorflow as tf

def batches_of_four(tokens):
    token_length = tokens.shape[0]
    splits = token_length // 4
    tokens = tokens[0 : splits * 4]
    return tf.split(tokens, num_or_size_splits=splits)

dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7]]))
print(batches_of_four(next(iter(dataset))))
This produces the output [<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4], dtype=int32)>], as expected.
If I now run the same function using Dataset.map:
for item in dataset.map(batches_of_four):
    print(item)
I instead get the following error:
File "<ipython-input-173-a09c55117ea2>", line 5, in batches_of_four *
    splits = token_length // 4
TypeError: unsupported operand type(s) for //: 'NoneType' and 'int'
I see that this is because token_length is None, but I don't understand why. I assume this has something to do with graph vs. eager execution, but the function works if I call it outside of .map, even if I annotate it with @tf.function.
Why is the behavior different inside .map? (Also: is there a better way of writing the batches_of_four function?)
CodePudding user response:
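Dataset.map traces your function once in graph mode against the dataset's element spec. Because the elements have different lengths, the static shape is unknown and tokens.shape[0] is None. Calling the function directly, even with @tf.function, works because it is traced for the concrete shape of the tensor you pass in. You should use tf.shape to get the dynamic shape of a tensor in graph mode: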
token_length = tf.shape(tokens)[0]
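To make the difference concrete, here is a minimal sketch that compares the static and dynamic length inside Dataset.map. The helper names static_length_is_known and dynamic_length are made up for this illustration; the dataset is the one from the question.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7]]))

def static_length_is_known(tokens):
    # Hypothetical helper: tokens.shape is the static shape seen at trace time.
    # The check below runs in Python while the function is being traced.
    return tf.constant(tokens.shape[0] is not None)

def dynamic_length(tokens):
    # Hypothetical helper: tf.shape is evaluated for each element at run time.
    return tf.shape(tokens)[0]

static_known = [bool(v) for v in dataset.map(static_length_is_known)]
lengths = [int(v) for v in dataset.map(dynamic_length)]
print(static_known)
print(lengths)
```

The static check is baked into the graph once for all elements, while tf.shape yields the actual length of each element.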
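A second problem is that the number of splits then becomes a scalar tensor. In graph mode, tf.split needs the number of splits as a Python integer, because the number of output tensors must be known when the graph is built. So passing a tensor as num_or_size_splits won't work either.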
Try this:
import tensorflow as tf

def body(i, m, n):
    # Append the next chunk of m to the TensorArray n, then advance i.
    # Note: chunk_size is read from the module-level variable defined below.
    n = n.write(n.size(), m[i:i + chunk_size])
    return tf.add(i, chunk_size), m, n

def split_data(data, chunk_size):
    length = tf.shape(data)[0]
    # Drop the trailing elements that do not fill a whole chunk.
    x = data[:(length // chunk_size) * chunk_size]
    ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    i0 = tf.constant(0)
    c = lambda i, m, n: tf.less(i, tf.shape(x)[0] - 1)
    _, _, out = tf.while_loop(c, body, loop_vars=[i0, x, ta])
    return out.stack()

chunk_size = 4
dataset = (
    tf.data.Dataset.from_tensor_slices(
        tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]))
    .map(lambda x: split_data(x, chunk_size))
    .flat_map(tf.data.Dataset.from_tensor_slices)
)
for item in dataset:
    print(item)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([5 6 7 8], shape=(4,), dtype=int32)
And see my other answer here.
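As for a simpler way of writing batches_of_four, one possible alternative to the loop is to truncate the tensor and reshape it. This is only a sketch and not part of the answer above: it assumes 1-D integer inputs, the chunk_size parameter is added for illustration, and it uses Dataset.unbatch in place of flat_map.

```python
import tensorflow as tf

def batches_of_four(tokens, chunk_size=4):
    # Use the dynamic length so this also works inside Dataset.map.
    length = tf.shape(tokens)[0]
    # Largest multiple of chunk_size that fits into the tensor.
    usable = (length // chunk_size) * chunk_size
    # Drop the remainder and fold the rest into rows of chunk_size elements.
    return tf.reshape(tokens[:usable], [-1, chunk_size])

dataset = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4, 5], [4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8, 9]]))

# unbatch() turns each (num_chunks, 4) element into separate length-4 elements.
chunks = dataset.map(batches_of_four).unbatch()
result = [chunk.numpy().tolist() for chunk in chunks]
print(result)
```

Because tf.reshape accepts -1 for one dimension, the number of chunks never has to be known as a Python integer, which sidesteps the tf.split limitation.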