I need to implement a layer in TensorFlow for a dataset of size N where each sample has a set of M independent features, each feature represented by a tensor of length L. I want to train M dense layers in parallel and then concatenate their output tensors.
I could implement such a layer with a for loop, as below:
import tensorflow as tf

class MyParallelDenseLayer(tf.keras.layers.Layer):
    def __init__(self, dense_kwargs, **kwargs):
        super().__init__(**kwargs)
        self.dense_kwargs = dense_kwargs

    def build(self, input_shape):
        # input_shape is (batch, M, L)
        self.N, self.M, self.L = input_shape
        # one independent Dense layer per feature slot
        self.list_dense_layers = [tf.keras.layers.Dense(**self.dense_kwargs) for _ in range(self.M)]
        super().build(input_shape)

    def call(self, inputs):
        # apply the i-th Dense layer to the i-th feature slice, then concatenate
        parallel_output = [self.list_dense_layers[i](inputs[:, i]) for i in range(self.M)]
        return tf.keras.layers.Concatenate()(parallel_output)
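For reference, a minimal usage sketch of this loop-based layer; the shapes and the units value are illustrative assumptions, not taken from my actual dataset:

# assumed shapes: N=16 samples per batch, M=200 feature slots, L=4, 10 units per Dense
layer = MyParallelDenseLayer(dense_kwargs={"units": 10})
x = tf.random.normal([16, 200, 4])
print(layer(x).shape)  # (16, 2000): the 200 outputs of size 10 are concatenated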
But the for loop in the 'call' method makes my layer extremely slow. Is there a faster way to implement this layer?
CodePudding user response:
This should be doable using einsum. Expand this layer to your liking with activation functions and whatnot.
class ParallelDense(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        super().build(input_shape)
        # one (L x units) kernel per feature slot: shape (M, L, units)
        self.kernel = self.add_weight(shape=[input_shape[1], input_shape[2], self.units])

    def call(self, inputs):
        # batched matmul over the M axis: (b, m, l) x (m, l, k) -> (b, m, k)
        return tf.einsum("bml, mlk -> bmk", inputs, self.kernel)
Test it:
b = 16   # batch size
m = 200  # no. of parallel feature slots
l = 4    # no. of input features per m
k = 10   # no. of output features per m

layer = ParallelDense(k)
inp = tf.random.normal([b, m, l])
print(layer(inp).shape)
(16, 200, 10)
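If you also want the bias and activation that Dense normally provides, and the concatenated (b, m*k) output of the original layer, something along these lines should work. This is a sketch of the "expand to your liking" suggestion above; the class name, the flatten_output flag, and the example call are my own additions, not part of the answer:

class ParallelDenseExpanded(tf.keras.layers.Layer):
    # Hypothetical extension of ParallelDense: adds a per-slot bias, an activation,
    # and optionally flattens (b, m, k) into the concatenated (b, m*k) output.
    def __init__(self, units, activation=None, flatten_output=False, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.flatten_output = flatten_output

    def build(self, input_shape):
        super().build(input_shape)
        m, l = input_shape[1], input_shape[2]
        self.kernel = self.add_weight(shape=[m, l, self.units])
        self.bias = self.add_weight(shape=[m, self.units], initializer="zeros")

    def call(self, inputs):
        # same einsum as above, plus bias and activation
        out = tf.einsum("bml, mlk -> bmk", inputs, self.kernel) + self.bias
        out = self.activation(out)
        if self.flatten_output:
            # equivalent to concatenating the M outputs along the last axis
            out = tf.reshape(out, [tf.shape(out)[0], -1])
        return out

layer = ParallelDenseExpanded(k, activation="relu", flatten_output=True)
print(layer(inp).shape)  # (16, 2000)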