Passing pre-computed batches to TensorFlow's fit() method

Time:06-23

I would like to train a Keras model (say a simple FFNN) using the model.fit() method rather than doing it 'by hand' (i.e. with a custom training loop built on tf.GradientTape). However, the loss function I need to use is quite elaborate and cannot be computed on randomly generated batches of data. As a result, I need to train the model on batches assembled 'by hand' (i.e. the data that goes into each batch needs to have certain properties and cannot be randomly assigned).

Can I somehow pass pre-computed batches to the fit() method?

CodePudding user response:

One solution is to subclass tf.keras.utils.Sequence. Its __getitem__ method returns the batch for a given index, so you can return your own pre-computed batch there.

import tensorflow as tf

class MySequence(tf.keras.utils.Sequence):
    def __init__(self, x_batch, y_batch) -> None:
        super().__init__()
        self.x_batch = x_batch   # ordered list of input batches
        self.y_batch = y_batch   # matching list of target batches
        self.leny = len(y_batch)

    def __len__(self):
        # Number of batches per epoch
        return self.leny

    def __getitem__(self, idx):
        # Return the pre-computed batch stored at position idx
        x = self.x_batch[idx]
        y = self.y_batch[idx]
        return x, y

You can pass an instance of this Sequence subclass to the model's fit() method.
Also set shuffle=False in the fit() arguments so that the batches are served in the order you stored them.
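
To make the last step concrete, here is a minimal end-to-end sketch of handing such a Sequence to fit(). Everything specific in it is an assumption for illustration only: the class name PrecomputedBatches, the toy random data (3 batches of 4 samples with 5 features), the small network, and the "mse" loss, which merely stands in for the custom loss from the question.

```python
import numpy as np
import tensorflow as tf


class PrecomputedBatches(tf.keras.utils.Sequence):
    """Serves batches that were assembled ahead of time, in a fixed order."""

    def __init__(self, x_batches, y_batches):
        super().__init__()
        if len(x_batches) != len(y_batches):
            raise ValueError("x_batches and y_batches must have the same length")
        self.x_batches = x_batches
        self.y_batches = y_batches

    def __len__(self):
        # Number of batches per epoch.
        return len(self.y_batches)

    def __getitem__(self, idx):
        # Return the idx-th pre-computed batch unchanged.
        return self.x_batches[idx], self.y_batches[idx]


# Hypothetical toy data: 3 batches, each with 4 samples of 5 features.
rng = np.random.default_rng(0)
x_batches = [rng.normal(size=(4, 5)).astype("float32") for _ in range(3)]
y_batches = [rng.normal(size=(4, 1)).astype("float32") for _ in range(3)]

# A small feed-forward network standing in for the real model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")  # "mse" is only a placeholder loss

batches = PrecomputedBatches(x_batches, y_batches)

# shuffle=False keeps the batch order; batch_size is left unset because
# the Sequence already defines the batches.
history = model.fit(batches, epochs=2, shuffle=False, verbose=0)
```

Note that batch_size is not passed to fit() here, since each item of the Sequence is already a full batch. The batches also do not have to be the same size, as long as the model and the loss can handle each one.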
