Exploding tensor after using Dataset and .batch

Time:05-24

I have a NumPy array of shape (100,4,30). It represents 100 samples, each made up of 4 encodings of length 30. The 4 encodings in each row are related.

I want to get a TensorFlow dataset, batched, where related samples are in the same batch.

I'm trying to do:

First, I use np.vsplit to get a list of length 100, where each element of the list holds the 4 related samples.

Now if I call tf.data.Dataset.from_tensor_slices(...).batch(1) on this list of lists, I get a batch that contains a tensor of shape (4,1,30).

I want this batch to contain 4 tensors of shape (1,30).

How can I achieve this?
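A minimal NumPy-only sketch of what the np.vsplit step returns, with zeros standing in for the real encodings. np.vsplit splits along axis 0 but keeps that axis in every piece, so each piece is (1, 4, 30) rather than (4, 30). Iterating over the first axis drops it instead.

```python
import numpy as np

# Stand-in data: 100 samples, each with 4 related encodings of length 30
data = np.zeros((100, 4, 30))

# np.vsplit splits along axis 0 and keeps a length-1 axis 0 in every piece
pieces = np.vsplit(data, 100)
print(len(pieces), pieces[0].shape)  # 100 (1, 4, 30)

# Iterating over axis 0 yields the 4 related encodings without the extra axis
groups = list(data)
print(len(groups), groups[0].shape)  # 100 (4, 30)
```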

Answer:

I may have misunderstood you, but if you just leave out the np.vsplit step:

import numpy as np
import tensorflow as tf

data = np.zeros((100, 4, 30))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
for element in data_ds.take(1):
    print(element.shape)

you will get:

(1, 4, 30)

(so one batch contains all 4 related encodings).

If you really want the contents of a batch to be 4 tensors of shape (1, 30), you can do:

data = np.expand_dims(data, axis=2)

before dataset creation.
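A NumPy-only check of the resulting shapes. Here expanded[0] stands in for one dataset element before .batch(1) adds its leading batch axis; in TensorFlow the batch itself would then have shape (1, 4, 1, 30).

```python
import numpy as np

data = np.zeros((100, 4, 30))

# Insert a length-1 axis after the "4 related encodings" axis
expanded = np.expand_dims(data, axis=2)
print(expanded.shape)  # (100, 4, 1, 30)

# One element as from_tensor_slices would yield it (before batching)
element = expanded[0]
parts = list(element)  # unpack along the encodings axis
print(len(parts), parts[0].shape)  # 4 (1, 30)
```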

EDIT:

I think I just understood your question. You want every batch to have 4 elements and those are the related encodings? You can achieve this by:

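# Flatten to (400, 30). Row-major order keeps the 4 encodings of each sample
# in adjacent rows; a np.swapaxes(data, 0, 1) before this would interleave samples.
data = np.reshape(data, (100*4, -1))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(4)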
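A NumPy-only sketch that checks which rows end up together, emulating .batch(4) by taking consecutive blocks of 4 rows. Each encoding is filled with its sample index, so a block is "grouped" when all 4 rows carry the same index. The helper batches_are_grouped is hypothetical, written only for this check.

```python
import numpy as np

# Fill every encoding with its sample index: data[i, j, :] == i
sample_id = np.arange(100).reshape(100, 1, 1)
data = np.broadcast_to(sample_id, (100, 4, 30)).astype(float)


def batches_are_grouped(flat, batch_size=4):
    """Hypothetical helper: True if every block of batch_size rows comes from one sample."""
    blocks = flat[:, 0].reshape(-1, batch_size)
    return bool(np.all(blocks == blocks[:, :1]))


# Plain reshape: row k is data[k // 4, k % 4], so related encodings stay adjacent
flat = np.reshape(data, (100 * 4, -1))
print(batches_are_grouped(flat))  # True

# Swapping axes first: row k is data[k % 100, k // 100], mixing samples
swapped = np.reshape(np.swapaxes(data, 0, 1), (100 * 4, -1))
print(batches_are_grouped(swapped))  # False
```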