I have a NumPy array of shape (100, 4, 30). It represents 100 samples, each made up of 4 encodings of length 30. The 4 encodings in each row are related.
I want to get a TensorFlow dataset, batched, where related samples are in the same batch.
What I'm trying so far: first, I use np.vsplit to get a list of length 100, where each element in the list holds the 4 related samples.
Now if I call tf.data.Dataset.from_tensor_slices(...).batch(1) on this list of lists, I get a batch that contains a tensor of shape (4,1,30).
I want this batch to contain 4 tensors of shape (1,30).
How can I achieve this?
CodePudding user response:
I may have misunderstood you, but if you just leave out the np.vsplit step:
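import numpy as np
import tensorflow as tf

data = np.zeros((100, 4, 30))
# Slice along the first axis: each element is one sample of shape (4, 30)
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
for element in data_ds.take(1):
    print(element.shape)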
you will get:
(1, 4, 30)
(so one batch contains all 4 related encodings).
If you really want each of the 4 encodings inside a batch to have shape (1, 30), you can do:
data = np.expand_dims(data, axis=2)
before creating the dataset. The array then has shape (100, 4, 1, 30), so each batch has shape (1, 4, 1, 30).
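If the goal is literally four separate tensors of shape (1, 30) per batch, rather than one stacked tensor, one possible approach is a tuple-structured dataset. This is only a sketch and not necessarily what the question intends: it assumes TensorFlow 2.x with eager execution, and the variable names (columns, first, shapes) are made up for illustration.

```python
import numpy as np
import tensorflow as tf

data = np.zeros((100, 4, 30), dtype=np.float32)

# Split the array into 4 arrays of shape (100, 30), one per related encoding.
columns = tuple(data[:, j, :] for j in range(4))

# A tuple input makes every dataset element a tuple of 4 tensors of shape (30,).
ds = tf.data.Dataset.from_tensor_slices(columns).batch(1)

# After batch(1), each element is a tuple of 4 tensors of shape (1, 30).
first = next(iter(ds))
shapes = [t.shape.as_list() for t in first]
print(len(first), shapes)
```

Each batch is then a tuple, so it can be unpacked directly, for example as a, b, c, d in a for loop over the dataset.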
EDIT:
I think I just understood your question. You want every batch to have 4 elements and those are the related encodings? You can achieve this by:
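# Flatten in row-major order: rows 4*i .. 4*i + 3 are the 4 related encodings of sample i.
# Do not swap axes 0 and 1 first, or batch(4) would group encodings from unrelated samples.
data = np.reshape(data, (100 * 4, -1))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(4)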
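To see which rows end up in the same batch of 4, here is a small NumPy-only check. It is a sketch under the assumption that batch(4) simply takes consecutive groups of 4 rows (which holds for an unshuffled dataset). Every encoding is tagged with the index of the sample it came from, and the flattening is done both with and without a prior np.swapaxes.

```python
import numpy as np

# Tag every encoding with its sample index: data[i, j, :] == i.
data = np.zeros((100, 4, 30))
data += np.arange(100).reshape(100, 1, 1)

# Plain row-major reshape: rows 4*i .. 4*i + 3 belong to sample i.
flat = data.reshape(100 * 4, -1)

# Emulate .batch(4) by cutting the rows into consecutive groups of 4.
batches = flat.reshape(100, 4, 30)
all_related = all(bool(np.all(batch == i)) for i, batch in enumerate(batches))

# With swapaxes first, the first group of 4 rows mixes samples 0, 1, 2 and 3.
swapped = np.swapaxes(data, 0, 1).reshape(100 * 4, -1)
first_swapped_batch_ids = swapped[:4, 0].tolist()

print(all_related, first_swapped_batch_ids)
```

So the plain reshape keeps related encodings together, while swapping axes 0 and 1 beforehand places encodings from different samples in the same batch.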