First I split the original tensor. Then, after some operations, I want to combine the pieces back into one tensor with the original shape and element order, i.e. the tensor as it was before splitting. I'm not sure I can just reuse the old tensor in TensorFlow's graph mode.
Each of the four dimensions of tensor_a has a size of at least 2.
tensor_a = tf.split(tensor_c, split_into, axis=1) # creating additional dimension
# some operations
tensor_a = tf.convert_to_tensor(tensor_a)
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
Answer:
The code below is self-contained and shows two ways of reconstructing the original tensor after splitting. I think the second approach, using tf.concat() instead of tf.convert_to_tensor(), is the neater one. I'm hoping the comments make the code self-explanatory.
import tensorflow as tf
# Construct a test tensor, to be split and then reconstructed
tensor_c = tf.reshape(tf.range(24), [2, 6, 2])
print("tensor_c")
print(tensor_c.numpy())
# Split it, as the question does
list_of_tensor_a = tf.split(tensor_c, 3, axis=1)
print("\nlist_of_tensor_a")
print([t.numpy() for t in list_of_tensor_a])
# Create a tensor of shape (3, 2, 2, 2), as the question does
# This changes the original ordering of tensor_c. It was split on axis 1,
# and is now reassembled by creating a new axis 0
tensor_a = tf.convert_to_tensor(list_of_tensor_a)
print("\ntensor_a")
print(tensor_a.shape)
print(tensor_a.numpy())
# Reshape as in the question.
# Does not reconstruct tensor_c, since the ordering has been changed
first, second, third, fourth = tensor_a.shape
tensor_b = tf.reshape(tensor_a, (second, first * third, fourth))
print("\ntensor_b - incorrect reconstruction of tensor_c")
print(tensor_b.numpy())
# Correct reconstruction, first approach.
# Use tf.transpose() to restore the original order
tensor_b2 = tf.reshape(tf.transpose(tensor_a, [1, 0, 2, 3]), (second, first * third, fourth))
print("\ntensor_b2 - correct reconstruction of tensor_c")
print(tensor_b2.numpy())
# Correct reconstruction, second (and neater) approach.
# Use tf.concat() instead of tf.convert_to_tensor()
tensor_b3 = tf.concat(list_of_tensor_a, axis=1)
print("\ntensor_b3 - correct reconstruction of tensor_c")
print(tensor_b3.numpy())
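Since the question mentions graph mode, here is a minimal sketch (not checked against the asker's real pipeline) of the tf.concat() approach wrapped in tf.function, so it is traced into a graph. `some_operations` is a hypothetical placeholder for the question's "# some operations" step, and multiplying by 10 is an arbitrary stand-in. The second function, `rejoin_stacked` (also hypothetical), shows one way to recover the original tensor if the operations yield a single stacked tensor of shape (3, 2, 2, 2) rather than a list: tf.unstack() along axis 0, then tf.concat() along the axis that was originally split (axis 1 here).

```python
import tensorflow as tf

# Hypothetical stand-in for the question's "# some operations" step.
def some_operations(piece):
    return piece * 10

@tf.function  # traced into a graph, i.e. graph mode
def split_process_concat(x, num_splits=3):
    pieces = tf.split(x, num_splits, axis=1)
    pieces = [some_operations(p) for p in pieces]
    # Concatenating along the same axis that was split restores the original layout.
    return tf.concat(pieces, axis=1)

@tf.function
def rejoin_stacked(stacked):
    # `stacked` has shape (num_splits, ...), as produced by tf.convert_to_tensor()
    # or tf.stack() on the list of pieces. tf.unstack() needs the size of axis 0
    # to be known when the function is traced.
    return tf.concat(tf.unstack(stacked, axis=0), axis=1)

tensor_c = tf.reshape(tf.range(24), [2, 6, 2])

result = split_process_concat(tensor_c)
print(result.shape)  # (2, 6, 2)
print(bool(tf.reduce_all(tf.equal(result, tensor_c * 10))))  # True

stacked = tf.stack(tf.split(tensor_c, 3, axis=1))  # shape (3, 2, 2, 2)
rebuilt = rejoin_stacked(stacked)
print(bool(tf.reduce_all(tf.equal(rebuilt, tensor_c))))  # True
```

If the size of the stacked axis is not known at trace time, tf.unstack() cannot infer how many pieces to return; the tf.transpose() plus tf.reshape() approach above should still work in that case if the target shape is computed with tf.shape() instead of the static .shape.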