Using TF 2.11.0 with a GPU on Colab.
I am seeing the amount of system memory used grow with every batch when I let the fit() method run (the callback below only checks per epoch). This is a very basic CycleGAN class:
import os
import gc
import psutil
import tensorflow as tf
from tensorflow import keras
from keras.callbacks import Callback

class MemoryUsageCallback(Callback):
    '''Monitor memory usage on epoch begin and end, collect garbage'''

    def on_epoch_begin(self, epoch, logs=None):
        print('**Epoch {}**'.format(epoch))
        print('Memory usage on epoch begin: {}'.format(psutil.Process(os.getpid()).memory_info().rss))

    def on_epoch_end(self, epoch, logs=None):
        gc.collect()
        tf.keras.backend.clear_session()
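Since the growth seems to happen per batch, the same RSS check could also run in a per-batch hook; a minimal sketch of such a variant (the class name and the every argument are just placeholders, not part of my original code):

class BatchMemoryUsageCallback(Callback):
    '''Print the process RSS every `every` training batches'''

    def __init__(self, every=50):
        super().__init__()
        self.every = every

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.every == 0:
            print('Memory usage after batch {}: {}'.format(
                batch, psutil.Process(os.getpid()).memory_info().rss))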
class CycleGan(keras.Model):
    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
    ):
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn

    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
        with tf.GradientTape(persistent=True) as tape:
            # Pass the images through the generators
            fake_monet = self.m_gen(real_photo, training=True)
            fake_monet_resized = tf.image.resize(fake_monet, [256, 256])
            cycled_photo = self.p_gen(fake_monet_resized, training=True)

            fake_photo = self.p_gen(real_monet, training=True)
            fake_photo_resized = tf.image.resize(fake_photo, [256, 256])
            cycled_monet = self.m_gen(fake_photo_resized, training=True)

            # Resize original images for the discriminators
            real_monet = tf.image.resize(real_monet, [320, 320])
            real_photo = tf.image.resize(real_photo, [320, 320])

            # Calculate the discriminators' answers
            disc_real_monet = tf.reduce_mean(self.m_disc(real_monet, training=True), axis=[1, 2])
            disc_real_photo = tf.reduce_mean(self.p_disc(real_photo, training=True), axis=[1, 2])
            disc_fake_monet = tf.reduce_mean(self.m_disc(fake_monet, training=True), axis=[1, 2])
            disc_fake_photo = tf.reduce_mean(self.p_disc(fake_photo, training=True), axis=[1, 2])

            # Calculate cycle loss
            cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + \
                         self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # Evaluate total generator loss
            total_monet_gen_loss = self.gen_loss_fn(disc_fake_monet) + cycle_loss
            total_photo_gen_loss = self.gen_loss_fn(disc_fake_photo) + cycle_loss

            # Evaluate discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generators and discriminators
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)
        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizers
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }
And I train the model like this:
cycle_gan_model = CycleGan(
    monet_generator, photo_generator, monet_discriminator, photo_discriminator)

cycle_gan_model.compile(
    m_gen_optimizer = monet_generator_optimizer,
    p_gen_optimizer = photo_generator_optimizer,
    m_disc_optimizer = monet_discriminator_optimizer,
    p_disc_optimizer = photo_discriminator_optimizer,
    gen_loss_fn = generator_loss,
    disc_loss_fn = discriminator_loss,
    cycle_loss_fn = calc_cycle_loss,
)

callbacks = [MemoryUsageCallback()]
zipped_dataset = tf.data.Dataset.zip((monet_dataset, photos_dataset))

cycle_gan_model.fit(
    zipped_dataset,
    epochs=20,
    steps_per_epoch=150,
    callbacks=callbacks
)
Note that the datasets I am using come from generators I defined earlier, which I am not so sure work well >.<
The memory usage at the start of each epoch keeps increasing:
3195469824->3974574080->4476375040->4954685440->...7341445120->...
it doesn't seem to stop.
I tried:
- Changing my activation functions from ReLU to LeakyReLU.
- Checking that run_eagerly=True.
- Putting my activation functions in separate layers.
- Adding a custom callback that collects garbage and clears the Keras backend at the end of each epoch.
- Lowering the batch size (right now it's 2).
- Playing around with the layers of the models.
- Changing the optimizers to Adafactor as it is more memory efficient.
I expected the memory usage to stabilize after the first epoch, but instead it kept increasing.
---UPDATE---
Changed the train_step to
def train_step(self, batch_data):
    return {
        "monet_gen_loss": 0,
        "photo_gen_loss": 0,
        "monet_disc_loss": 0,
        "photo_disc_loss": 0
    }
and yet I am still seeing the memory increase, so I suspect it might be the datasets.
I saved the datasets to my machine as image files and then used these functions to build the datasets:
import random
import numpy as np
from pathlib import Path
from PIL import Image

def image_loader(folder_path, batch_size):
    while True:
        folder_path = Path(folder_path)
        image_files = [file for file in folder_path.iterdir() if file.suffix == ".jpg"]
        # Shuffle the list of image files
        random.shuffle(image_files)
        for i in range(0, len(image_files), batch_size):
            batch_files = image_files[i:i + batch_size]
            images = [np.array(Image.open(str(file))) for file in batch_files]
            yield np.array(images)

BATCH_SIZE = 2
photos_loader = image_loader('/content/photos/', BATCH_SIZE)
def augment_element(image):
    image = tf.convert_to_tensor(image)
    # Flips and rotations
    image = tf.image.random_flip_left_right(image)
    image = tf.image.rot90(image, k=random.randint(0, 3))
    # Crop the image randomly and resize it back
    size = tf.random.uniform(shape=[], minval=100, maxval=256, dtype=tf.int32)
    image = tf.image.random_crop(image, size=[size, size, 3])
    image = tf.image.resize(image, [256, 256])
    return image.numpy()
def augmented_image_loader(folder_path, batch_size):
    while True:
        folder_path = Path(folder_path)
        image_files = [file for file in folder_path.iterdir() if file.suffix == ".jpg"]
        images = [np.array(Image.open(str(file))) for file in image_files]
        while True:
            random_images = random.choices(images, k=batch_size)
            augmented_images = [augment_element(image) for image in random_images]
            yield np.array(augmented_images)

augmented_monet_loader = augmented_image_loader('/content/monet/', BATCH_SIZE)
and finally
monet_dataset = tf.data.Dataset.from_generator(
    lambda: augmented_image_loader('/content/monet/', BATCH_SIZE),
    output_types=tf.float32,
    output_shapes=tf.TensorShape([BATCH_SIZE, None, None, 3])
)
photos_dataset = tf.data.Dataset.from_generator(
    lambda: image_loader('/content/photos/', BATCH_SIZE),
    output_types=tf.float32,
    output_shapes=tf.TensorShape([BATCH_SIZE, None, None, 3])
)

def normalize(image):
    image = image - 127.5
    image = image / 127.5
    return image

monet_dataset = monet_dataset.cache().map(normalize)
photos_dataset = photos_dataset.cache().map(normalize)
But then again, running the following code doesn't seem to increase my memory usage (except on the first iteration):

for i in range(1000):
    data = next(iter(zipped_dataset))
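For comparison, a variant of this test that keeps a single iterator alive (closer to what fit() does) and prints the process RSS every 100 steps would look something like this, reusing the psutil check from the callback above (just a sketch, not code I ran):

it = iter(zipped_dataset)
for i in range(1000):
    data = next(it)
    if i % 100 == 0:
        # Print the resident set size of the current process
        print(i, psutil.Process(os.getpid()).memory_info().rss)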
It's my first time working with tf datasets (this is part of a uni assignment), so I am a complete rookie in that department; any help would be appreciated! :D
CodePudding user response:
I noticed you wrote:
monet_dataset = monet_dataset.cache().map(normalize)
photos_dataset = photos_dataset.cache().map(normalize)
It's important to note that the cache() method (without a file argument) keeps every element it produces in memory, so it is not suitable for very large datasets. It is best used when the dataset fits into memory, when the elements are expensive to generate, and when you expect to iterate over the same dataset multiple times. Your generators run in an endless while True loop and keep producing new batches, so the cache grows with every batch consumed and the memory usage never stabilizes.
Replace it with
monet_dataset = monet_dataset.map(normalize)
photos_dataset = photos_dataset.map(normalize)
And it should solve the memory issue.
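For example, a minimal sketch of the dataset setup without cache() (the prefetch(tf.data.AUTOTUNE) at the end is an optional addition of mine, not part of your original code):

monet_dataset = tf.data.Dataset.from_generator(
    lambda: augmented_image_loader('/content/monet/', BATCH_SIZE),
    output_types=tf.float32,
    output_shapes=tf.TensorShape([BATCH_SIZE, None, None, 3])
).map(normalize).prefetch(tf.data.AUTOTUNE)  # normalize on the fly, overlap data prep with training

photos_dataset = tf.data.Dataset.from_generator(
    lambda: image_loader('/content/photos/', BATCH_SIZE),
    output_types=tf.float32,
    output_shapes=tf.TensorShape([BATCH_SIZE, None, None, 3])
).map(normalize).prefetch(tf.data.AUTOTUNE)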