I would like to train a Keras model (say a simple FFNN) using the model.fit() method rather than doing it 'by hand' (i.e. with tf.GradientTape, as explained for example here). However, the loss function I need to use is quite elaborate and cannot be computed on randomly generated batches of data. As a result, I need to train the model on batches computed 'by hand' (i.e. the data that goes into each batch must have certain properties and cannot be assigned at random).
Can I pass somehow pre-computed batches to the fit() method?
CodePudding user response:
One solution is to subclass tf.keras.utils.Sequence. You can then return your own pre-computed batch for a given index from the __getitem__ method.
import tensorflow as tf

class MySequence(tf.keras.utils.Sequence):
    def __init__(self, x_batch, y_batch) -> None:
        super().__init__()
        self.x_batch = x_batch  # ordered list of pre-computed input batches
        self.y_batch = y_batch  # ordered list of matching target batches
        self.leny = len(y_batch)

    def __len__(self):
        # number of batches per epoch
        return self.leny

    def __getitem__(self, idx):
        # return the pre-computed batch stored at position idx
        x = self.x_batch[idx]
        y = self.y_batch[idx]
        return x, y
You can then pass an instance of this Sequence subclass directly to model.fit(). Also set shuffle=False in the fit() arguments so that the batches are consumed in the order you built them.
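For illustration, a minimal usage sketch (the model architecture and the x_batch/y_batch lists below are hypothetical; the random arrays are only placeholders standing in for your pre-computed batches):

import numpy as np
import tensorflow as tf

# Two hand-built batches of 32 samples with 10 features each (placeholder data).
x_batch = [np.random.rand(32, 10).astype("float32") for _ in range(2)]
y_batch = [np.random.rand(32, 1).astype("float32") for _ in range(2)]

# Simple FFNN, just to show the wiring.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# fit() pulls one pre-computed batch per step via __getitem__.
model.fit(MySequence(x_batch, y_batch), epochs=5, shuffle=False)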