How to save memory when training a model in Keras where X is f(Y)?


I have a problem where my input X can be defined as a function applied to my output Y. I imagine I can save memory during training by keeping only Y in memory and creating X per example on the fly.

If this is my training code:

history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=2,
    # We pass some validation for
    # monitoring validation loss and metrics
    # at the end of each epoch
    validation_data=(x_val, y_val),
)

How can I define x_train as f(y_train) per example? Thanks!

CodePudding user response:

You can do that by writing your own training loop as explained here -> https://keras.io/guides/writing_a_training_loop_from_scratch/#using-the-gradienttape-a-first-endtoend-example

The part you're interested in is this:


epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, y_batch_train in enumerate(train_dataset):
        x_batch_train = F(y_batch_train)  # Compute X from Y with your function here; the rest of the code follows the tutorial
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step   1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    # The validation set can hold only Y as well; compute X the same way.
    for y_batch_val in val_dataset:
        x_batch_val = F(y_batch_val)
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

CodePudding user response:

You could build a tf.data dataset from a generator using from_generator. Something like this:

def create_ds(y_data):
    def generator():
        for y in y_data:
            yield f(y), y

    ds = tf.data.Dataset.from_generator(generator=generator, output_signature=...)
    ds = ds.batch(64).prefetch(buffer_size=tf.data.AUTOTUNE)
    return ds

ds = create_ds(y_train)
val_ds = create_ds(y_val)

history = model.fit(ds, epochs=2, validation_data=val_ds)
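
The output_signature placeholder depends on what f(y) and y actually look like. Assuming, purely as an example, that f(y) returns a float32 vector and y is a float32 scalar, it could be spelled out like this and passed as the output_signature argument inside create_ds:

output_signature = (
    tf.TensorSpec(shape=(None,), dtype=tf.float32),  # f(y): the generated input X (hypothetical shape/dtype)
    tf.TensorSpec(shape=(), dtype=tf.float32),       # y: the target (hypothetical shape/dtype)
)

Either way, only y_train and y_val need to live in memory; each X is computed on demand as the dataset is consumed.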