I would like to clear the memory / network every time I am done with training. I tried the approaches most often recommended online, namely `tf.compat.v1.reset_default_graph()` and `tf.keras.backend.clear_session()`, but, if I am interpreting my results correctly, they do not seem to work.
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend as K
# Bounds of the interval the random training data is scaled into.
upper_limit = 2
lower_limit = -2

# np.random.random draws from [0, 1); scale by the interval width and shift
# by the lower bound (the "+" was missing, which made these lines invalid).
training_input = np.random.random([100, 5]) * (upper_limit - lower_limit) + lower_limit
training_output = np.random.random([100, 1]) * 10 * (upper_limit - lower_limit) + lower_limit

# Small regression network: 5 inputs -> 12 ReLU units -> 1 output.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(5,)),
    tf.keras.layers.Dense(12, activation='relu'),
    tf.keras.layers.Dense(1)
])
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

for layer in model.layers:
    print("layer weights before fitting: ", layer.get_weights(), "\n")  # weights

model.fit(training_input, training_output, epochs=5, batch_size=100, verbose=0)

for layer in model.layers:
    print("layer weights after fitting: ", layer.get_weights(), "\n")  # weights
print("\n")

# Attempt to clear the network. NOTE: `model` is still referenced below and
# its layers are observed to report the same weights as right after fitting;
# these calls do not re-initialise the variables of an existing model object.
tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

print("after clear", "\n")
for layer in model.layers:
    print(layer.get_weights(), "\n")  # weights
When I print the layer weights after attempting to clear the network, I get the same weight values as before clearing the session.
CodePudding user response:
I think what you are looking for is a way to reset the weights of your model, and that is not really related to the session or the graph (with some exceptions).
Resetting the weights is currently a debated topic. You can find how to do it for most cases in the linked discussion, but as you can see there, nobody is currently planning to implement this as a built-in function.
For easy access, I am posting the currently proposed solution below:
def reset_weights(model):
    """Re-initialise the weights of ``model`` in place.

    Walks ``model.layers`` and, for every attribute whose name contains
    ``"initializer"`` (``kernel_initializer``, ``bias_initializer``, ...),
    looks up the matching variable (``kernel``, ``bias``, ...) and assigns
    it a fresh value produced by that initializer.

    Nested models (a ``tf.keras.Model`` used as a layer) are handled
    recursively. Layers that expose a ``cell`` attribute keep their
    initializers on the cell, so the cell is inspected instead of the layer.

    Initializer entries with no usable variable are skipped rather than
    raising: this covers variables that are ``None`` (e.g. a layer created
    with ``use_bias=False``), attributes that do not exist yet (an unbuilt
    layer), and initializer attributes that are not callable.

    Args:
        model: A Keras model (or any object with a ``layers`` iterable).

    Returns:
        None. The variables are updated in place via ``assign``.

    Note:
        Unless the initializers were created with a seed, every call
        produces different weights.
    """
    for layer in model.layers:
        # A model used as a layer: descend into its own layers.
        if isinstance(layer, tf.keras.Model):
            reset_weights(layer)
            continue

        # Where are the initializers? On the cell if there is one,
        # otherwise on the layer itself.
        init_container = layer.cell if hasattr(layer, 'cell') else layer

        for key, initializer in vars(init_container).items():
            # Only attributes named like "<something>_initializer" matter.
            if "initializer" not in key:
                continue
            if not callable(initializer):
                continue

            # Find the corresponding variable, like the kernel or the bias.
            # `recurrent_initializer` is the one name that does not map by
            # simply dropping the suffix.
            if key == 'recurrent_initializer':
                var_name = 'recurrent_kernel'
            else:
                var_name = key.replace("_initializer", "")

            # Missing attribute or None means there is nothing to reset
            # (e.g. bias disabled, layer not built yet).
            var = getattr(init_container, var_name, None)
            if var is None:
                continue

            var.assign(initializer(var.shape, var.dtype))
Remember that if you do not define a seed for the initializers, the weights will be different each time you call `reset_weights`.