How to use correlation in keras lambda layer correctly?

Time:05-13

What I want to do is compute correlations inside the model and use the result as input to the next layer. I could compute the correlations beforehand; however, I have a lot of input features, and calculating all pairwise feature correlations is not feasible. My idea is to reduce the features to a manageable number first and then compute their correlations. Here is a minimal example where I ran into a problem:
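To put numbers on this, take the sizes from the example below (3000 input features reduced to 64 Conv1D filters). The number of distinct feature pairs then shrinks by a factor of roughly 2200. The helper name is only for illustration:

```python
def n_unique_pairs(n_features: int) -> int:
    """Distinct off-diagonal entries of an n x n (symmetric) correlation matrix."""
    return n_features * (n_features - 1) // 2


for f in (3000, 64):
    # features, distinct pairs, entries in the full f x f matrix
    print(f, n_unique_pairs(f), f * f)
```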

from tensorflow import keras
import tensorflow_probability as tfp

def the_corr(x):
  return tfp.stats.correlation(x, sample_axis = 1)

input = keras.Input(shape=(100,3000,))
x = keras.layers.Conv1D(filters=64, kernel_size=1,activation='relu') (input)
x = keras.layers.Lambda(the_corr, output_shape=(64,64,)) (x)
#x = keras.layers.Dense(3) (x)
model = keras.Model(input, x)
model.summary()

However, model.summary() gives:

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_9 (InputLayer)        [(None, 100, 3000)]       0         
                                                                 
 conv1d_3 (Conv1D)           (None, 100, 64)           192064    
                                                                 
 lambda_8 (Lambda)           (None, None, None)        0         
                                                                 
=================================================================
Total params: 192,064
Trainable params: 192,064
Non-trainable params: 0
_________________________________________________________________

The Lambda layer does not produce the correct output shape and completely ignores the output_shape=(64,64,) argument. Consequently, if the commented-out line is restored, the following Dense layer throws an error:

ValueError: The last dimension of the inputs to a Dense layer should be defined. Found None. Full input shape received: (None, None, None)
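The parameter count shows why a Dense layer needs a defined last dimension. Its kernel has shape (last_dim, units), plus one bias per unit, so Keras cannot create the weights while last_dim is None. The arithmetic below checks this against the parameter counts in the summaries of this post. The helper names are hypothetical:

```python
def dense_params(last_dim: int, units: int) -> int:
    """Dense layer: a (last_dim x units) kernel plus `units` biases."""
    return last_dim * units + units


def conv1d_params(in_channels: int, filters: int, kernel_size: int) -> int:
    """Conv1D layer: a (kernel_size x in_channels x filters) kernel plus `filters` biases."""
    return kernel_size * in_channels * filters + filters


print(conv1d_params(3000, 64, 1))  # matches the conv1d rows of the summaries
print(dense_params(64, 3))         # matches the dense rows of the summaries
# With last_dim unknown (None), the kernel size is undefined, hence the ValueError.
```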

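I can also drop the sample_axis=1 argument from tfp.stats.correlation(). However, the default sample_axis=0 is the batch axis, so the correlation is then computed across the batch, and the batch dimension (None) disappears from the output: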

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_12 (InputLayer)       [(None, 100, 3000)]       0         
                                                                 
 conv1d_6 (Conv1D)           (None, 100, 64)           192064    
                                                                 
 lambda_11 (Lambda)          (100, 64, 64)             0         
                                                                 
 dense_5 (Dense)             (100, 64, 3)              195       
                                                                 
=================================================================
Total params: 192,259
Trainable params: 192,259
Non-trainable params: 0
_________________________________________________________________

This is also not what I want, because the samples in a batch are independent and should not be mixed together.
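The shapes in both summaries follow from which axis gets reduced. Below is a minimal NumPy sketch of Pearson correlation between last-axis features. It assumes the same axis convention as tfp.stats.correlation with its default event_axis=-1; it reproduces the math, not TFP's implementation, and the function name is hypothetical. Reducing over axis 1 keeps one 64x64 matrix per example. Reducing over axis 0 mixes the examples and leaves one matrix per time step:

```python
import numpy as np


def correlation(x, sample_axis):
    """Pearson correlation of the last-axis features, reduced over `sample_axis`."""
    xc = x - x.mean(axis=sample_axis, keepdims=True)   # center over the sample axis
    xc = np.moveaxis(xc, sample_axis, -2)              # -> (..., samples, features)
    cov = np.swapaxes(xc, -1, -2) @ xc / xc.shape[-2]  # (..., features, features)
    std = np.sqrt(np.diagonal(cov, axis1=-2, axis2=-1))
    return cov / (std[..., :, None] * std[..., None, :])


batch = np.random.default_rng(0).normal(size=(8, 100, 64))  # (batch, steps, filters)
print(correlation(batch, sample_axis=1).shape)  # (8, 64, 64): one matrix per example
print(correlation(batch, sample_axis=0).shape)  # (100, 64, 64): batch axis reduced away
```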

What am I doing wrong? Is this even possible?

Answer:

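You can try setting keepdims=True in tfp.stats.correlation and then squeezing out the extra axis. With this change, Keras infers the static output shape (None, 64, 64) by itself, so output_shape is no longer needed: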

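import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp

def the_corr(x):
    # x: (batch, 100, 64); correlate the 64 features over axis 1 (the 100 steps).
    x = tfp.stats.correlation(x,
                              sample_axis=1,
                              keepdims=True)

    # keepdims=True keeps the reduced sample axis as size 1:
    # (batch, 1, 64, 64) -> squeeze it to (batch, 64, 64).
    x = tf.squeeze(x, axis=1)
    return x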

inputs = keras.Input(shape=(100, 3000))
x = keras.layers.Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)
x = keras.layers.Lambda(the_corr)(x)
x = keras.layers.Dense(3)(x)
model = keras.Model(inputs, x)
model.summary()

Summary:

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 100, 3000)]       0         
                                                                 
 conv1d (Conv1D)             (None, 100, 64)           192064    
                                                                 
 lambda (Lambda)             (None, 64, 64)            0         
                                                                 
 dense (Dense)               (None, 64, 3)             195       
                                                                 
=================================================================
Total params: 192,259
Trainable params: 192,259
Non-trainable params: 0
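If the TFP dependency is not essential, you could instead compute the correlation with core TensorFlow ops. This is a sketch that assumes plain Pearson correlation over axis 1 is what is wanted. tf.matmul and broadcasting preserve static shapes, so Keras should infer (None, 64, 64) with neither output_shape nor a squeeze. The name batch_correlation and the eps guard are hypothetical additions, not part of TFP:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def batch_correlation(x, eps=1e-8):
    """(batch, samples, features) -> (batch, features, features), correlating over axis 1."""
    # Center each feature over the sample axis, separately for every example.
    xc = x - tf.reduce_mean(x, axis=1, keepdims=True)
    # Covariance: (batch, features, samples) @ (batch, samples, features).
    n = tf.cast(tf.shape(x)[1], x.dtype)
    cov = tf.matmul(xc, xc, transpose_a=True) / n
    # Per-feature standard deviation from the diagonal; eps (hypothetical) avoids 0/0.
    std = tf.sqrt(tf.linalg.diag_part(cov) + eps)  # (batch, features)
    # Normalize by the outer product of the standard deviations.
    return cov / (std[:, :, None] * std[:, None, :])


inputs = keras.Input(shape=(100, 3000))
x = keras.layers.Conv1D(filters=64, kernel_size=1, activation="relu")(inputs)
x = keras.layers.Lambda(batch_correlation)(x)  # static shape: (None, 64, 64)
outputs = keras.layers.Dense(3)(x)
model = keras.Model(inputs, outputs)

# Quick check against NumPy on a tiny random batch.
sample = np.random.default_rng(0).normal(size=(2, 50, 4)).astype("float32")
corr = batch_correlation(tf.constant(sample)).numpy()
print(np.allclose(corr[0], np.corrcoef(sample[0], rowvar=False), atol=1e-4))
```

The eps term only matters when a ReLU filter is constant (for example all zeros) over the 100 steps. Such a feature then gets a correlation row of 0 instead of a division by zero, which may differ from what tfp.stats.correlation returns in that case.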