Saving layer outputs in Keras functional API

Time:12-25

I'm implementing U-Net with the Keras functional API. One aspect of U-Net is its 'horizontal' context connections (similar to residual connections). I create the downsampling and upsampling layers with for loops. For example:

for filters in [32, 64, 128]:
    x = inverted_residual_block(x, expand=filters*2, squeeze=filters)
    x = inverted_residual_block(x, expand=filters*2, squeeze=filters)

    x = down_sampling_block(x, filters=filters)

Can I collect the intermediate outputs in a simple list, like

horizontal_connection.append(x)

in my loop, and use them during upsampling? I can't find any notes in the Keras or TF docs about this approach causing issues. I'm concerned that it will cause an error during graph generation. If anyone has experience or insight to share, it would be appreciated!
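For reference, here is a minimal sketch of the pattern being asked about, assuming TensorFlow 2.x with tf.keras. `conv_block` is a hypothetical stand-in for the asker's `inverted_residual_block`, and `MaxPooling2D`/`Conv2DTranspose` stand in for `down_sampling_block` and the upsampling block, which are not shown. The list only holds symbolic tensors while the model is being wired up:

```python
import tensorflow as tf
from tensorflow import keras


def conv_block(x, filters):
    # Hypothetical stand-in for inverted_residual_block: two 3x3 convolutions.
    x = keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    x = keras.layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    return x


def build_unet(input_shape=(64, 64, 3), num_classes=1):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    horizontal_connections = []  # plain Python list of symbolic tensors

    # Encoder: save each feature map before it is downsampled.
    for filters in [32, 64, 128]:
        x = conv_block(x, filters)
        horizontal_connections.append(x)
        x = keras.layers.MaxPooling2D(2)(x)

    x = conv_block(x, 256)  # bottleneck

    # Decoder: upsample, then concatenate the matching saved tensor.
    for filters in [128, 64, 32]:
        x = keras.layers.Conv2DTranspose(filters, 2, strides=2, padding="same")(x)
        skip = horizontal_connections.pop()  # last saved = same resolution
        x = keras.layers.Concatenate()([x, skip])
        x = conv_block(x, filters)

    outputs = keras.layers.Conv2D(num_classes, 1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)


model = build_unet()
```

Popping from the list returns the saved tensors in reverse order, which matches the decoder's resolution at each step.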

User response:

I think this can work only in eager execution, because a list is a Python construct. In graph execution, Python constructs are evaluated just once, when the function is traced, before the TensorFlow graph actually runs. If you need to build up arrays inside loops in graph execution, you can use tf.TensorArray (https://www.tensorflow.org/api_docs/python/tf/TensorArray).
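A minimal sketch of the tf.TensorArray approach, assuming TensorFlow 2.x; `running_squares` is a hypothetical example function. Because the loop runs over `tf.range` of a tensor, AutoGraph turns it into a graph loop, and the TensorArray accumulates one value per iteration inside the graph:

```python
import tensorflow as tf


@tf.function
def running_squares(n):
    # dynamic_size=True lets the array grow as the loop writes to it.
    ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    for i in tf.range(n):
        # write() returns a new TensorArray; it must be reassigned.
        ta = ta.write(i, i * i)
    return ta.stack()  # pack all written values into one tensor


print(running_squares(tf.constant(5)).numpy().tolist())  # [0, 1, 4, 9, 16]
```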

Eager execution evaluates tensor operations together with the Python code, so Python lists and TensorFlow tensors can be mixed freely. Eager execution is good for debugging, but graph execution is usually faster.
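A small sketch of the difference, assuming TensorFlow 2.x (`double_graph`, `double_eager` and the two log lists are hypothetical names). A Python side effect inside a `tf.function` runs only while the function is traced, whereas the same code without the decorator runs on every call:

```python
import tensorflow as tf

trace_log = []  # appended to only when double_graph is traced
eager_log = []  # appended to on every call of double_eager


@tf.function
def double_graph(x):
    # Python side effect: runs during tracing, not on each graph execution.
    trace_log.append("traced")
    return x * 2


def double_eager(x):
    # In eager mode, Python code runs together with the tensor ops.
    eager_log.append("ran")
    return x * 2


for value in [1.0, 2.0]:
    double_graph(tf.constant(value))  # same input signature: traced once
    double_eager(tf.constant(value))

print(len(trace_log), len(eager_log))  # 1 2
```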

User response:

I've gone through and implemented these uses of lists, and I have seen no major drop in performance. My understanding of the Functional API is incomplete, but I believe this makes sense: the list is only used to reference different layers' outputs while the model is being wired up, so it doesn't add anything to the graph. The graph runs through the layers of the model and shouldn't be affected by how the model was constructed. In addition, even if that assumption is wrong, AutoGraph takes care of converting plain Python code into graph code, e.g. converting while loops whose condition depends on a tensor into tf.while_loop().
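Both points can be checked with a short sketch, assuming TensorFlow 2.x (`skips`, `predict_step` and `count_halvings` are hypothetical names). The first part builds a functional model with a Python list and then calls it inside a `tf.function`; tracing that function does not re-run the list-building loop, because the layer connections were fixed when `keras.Model` was created. The second part shows AutoGraph converting a tensor-dependent Python `while` loop:

```python
import tensorflow as tf
from tensorflow import keras

# Build a small functional model, collecting layer outputs in a list.
inputs = keras.Input(shape=(8,))
x = inputs
skips = []  # Python list used only while wiring the model
for units in [16, 16]:
    x = keras.layers.Dense(units, activation="relu")(x)
    skips.append(x)
x = keras.layers.Add()(skips)  # combine the saved outputs
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)


@tf.function
def predict_step(batch):
    # The model's connections are already fixed; only its layers run here.
    return model(batch, training=False)


@tf.function
def count_halvings(x):
    steps = tf.constant(0)
    while x > 1:  # tensor condition: AutoGraph emits a tf.while_loop
        x = x // 2
        steps += 1
    return steps


print(predict_step(tf.zeros((4, 8))).shape)  # (4, 1)
print(int(count_halvings(tf.constant(20))))  # 4
```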

Thank you for commenting on my question!
