How to clear the entire network structure in TensorFlow

Time:10-07

I would like to clear the memory/network every time I finish training. I tried the alternatives proposed online, but if I am interpreting my results correctly, they do not seem to work. I use tf.compat.v1.reset_default_graph() and tf.keras.backend.clear_session(), since they are the most commonly recommended.

import numpy as np
import tensorflow as tf


upper_limit = 2
lower_limit = -2

training_input = np.random.random([100, 5]) * (upper_limit - lower_limit) + lower_limit
training_output = np.random.random([100, 1]) * 10 * (upper_limit - lower_limit) + lower_limit


model = tf.keras.Sequential([  
    tf.keras.layers.Flatten(input_shape=(5,)),
    tf.keras.layers.Dense(12, activation='relu'),
    tf.keras.layers.Dense(1) 
])   
                               
model.compile(loss="mse",optimizer = tf.keras.optimizers.Adam(learning_rate=0.01))

for layer in model.layers:
    print("layer weights before fitting: ",layer.get_weights(),"\n") # weights
    
model.fit(training_input, training_output, epochs=5, batch_size=100,verbose=0)
    

for layer in model.layers:
    print("layer weights after fitting: ", layer.get_weights(), "\n")  # weights
print("\n") 

tf.compat.v1.reset_default_graph()
tf.keras.backend.clear_session()

print("after clear", "\n")
for layer in model.layers:
    print(layer.get_weights(), "\n")  # weights

When I print the layer weights after attempting to clear the network, I get the same weight values as before clearing the session.

Answer:

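I think what you are looking for is resetting the weights of your model, which is not really related to the session or the graph (with some exceptions). tf.keras.backend.clear_session() resets the global state that Keras manages, but your model variable still refers to the same Python object, and that object owns its trained variables, so get_weights() keeps returning the trained values.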
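The simplest way to "start over" is to drop the old model object and build a new one. The sketch below assumes you wrap the question's architecture in a hypothetical factory function, `build_model`, and uses `tf.keras.Input` instead of `Flatten(input_shape=...)`; it shows that the old model keeps its trained weights after `clear_session()`, while a freshly built model gets newly initialized ones.

```python
import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical factory: same architecture and compile settings as the question.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(5,)),
        tf.keras.layers.Dense(12, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))
    return model


upper_limit, lower_limit = 2, -2
x = np.random.random([100, 5]) * (upper_limit - lower_limit) + lower_limit
y = np.random.random([100, 1]) * 10 * (upper_limit - lower_limit) + lower_limit

model = build_model()
model.fit(x, y, epochs=5, batch_size=100, verbose=0)
trained = model.get_weights()

# clear_session() resets Keras' global state, but the existing `model`
# object still owns its variables, so its weights are unchanged.
tf.keras.backend.clear_session()
still_trained = model.get_weights()

# To really start from scratch, discard the old object and build a new model.
del model
model = build_model()
fresh = model.get_weights()
```

Rebuilding also gives you a new optimizer, so no training state carries over from the previous run.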

Resetting weights is currently a debated topic. You can find how to do it for most cases here, but as you can see, nobody is currently planning to implement this as a built-in function.
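If you only need to return to the model's original starting point (rather than draw new random values), a common alternative is to snapshot the weights right after building the model and restore them later with `set_weights()`. This is a minimal sketch under that assumption, using the question's architecture; note that `set_weights()` does not reset the optimizer's internal state (for example Adam's moment estimates), so the sketch recompiles with a new optimizer instance.

```python
import numpy as np
import tensorflow as tf

upper_limit, lower_limit = 2, -2

# Same architecture as in the question (Input layer instead of Flatten).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(12, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

# Snapshot the freshly initialized weights (get_weights returns NumPy copies).
initial_weights = model.get_weights()

x = np.random.random([100, 5]) * (upper_limit - lower_limit) + lower_limit
y = np.random.random([100, 1]) * 10 * (upper_limit - lower_limit) + lower_limit
model.fit(x, y, epochs=5, batch_size=100, verbose=0)
trained_weights = model.get_weights()

# Restore the snapshot: the model is back to its exact starting weights.
model.set_weights(initial_weights)

# Recompile with a new optimizer instance so no optimizer state carries over.
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))
restored_weights = model.get_weights()
```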

For easy access, I post the current proposal below:

def reset_weights(model):
    """Re-initialize the weights of `model` in place using each layer's own initializers."""
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):  # if you're using a model as a layer
            reset_weights(layer)  # apply the function recursively
            continue

        # Where are the initializers? Recurrent layers keep them on `layer.cell`.
        if hasattr(layer, 'cell'):
            init_container = layer.cell
        else:
            init_container = layer

        for key, initializer in init_container.__dict__.items():
            if "initializer" not in key:  # is this attribute an initializer?
                continue  # if not, skip it

            # Find the corresponding variable, like the kernel or the bias.
            if key == 'recurrent_initializer':  # special case check
                var = getattr(init_container, 'recurrent_kernel', None)
            else:
                var = getattr(init_container, key.replace("_initializer", ""), None)

            # Skip weights that were never created (e.g. the bias when use_bias=False).
            if var is None:
                continue

            var.assign(initializer(var.shape, var.dtype))

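Remember that if you do not set a seed, the re-initialized weights will generally be different each time you call reset_weights (some Keras versions instead warn that an unseeded initializer called more than once returns identical values).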
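If you want each fresh model to start from the same random weights, one option is to fix the global seeds right before building it. This sketch assumes TensorFlow 2.7 or later, where `tf.keras.utils.set_random_seed` (which seeds Python, NumPy and TensorFlow) is available; `build_model` and `SEED` are hypothetical names, and it covers rebuilding the model rather than calling `reset_weights`, whose seeding behavior depends on your Keras version.

```python
import numpy as np
import tensorflow as tf

SEED = 42  # hypothetical example value


def build_model():
    # Hypothetical factory mirroring the question's architecture.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(5,)),
        tf.keras.layers.Dense(12, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


# Clear Keras' global state, fix all seeds, then build: first run.
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(SEED)
first = build_model().get_weights()

# Repeat with the same seed: the initial weights come out identical.
tf.keras.backend.clear_session()
tf.keras.utils.set_random_seed(SEED)
second = build_model().get_weights()
```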
