Custom dense layer in Keras/TensorFlow with 2D input, 2D weight, and 2D bias?

Time:04-05

I am working with images and want to build a custom layer for my model. I want to multiply each pixel by a weight and add a bias to it (x.w + b). I know that Flatten would work for this, but I have additional calculations to do, and some of them may need a transpose as well. My question: can I multiply the 2D form of each input by a 2D weight and add a two-dimensional bias in my custom dense layer? I tried, but the weight shape only takes a one-dimensional input size and the number of output units, shape=(input_dim, units). I want input_dim to be 2D for both the weight and the bias!

import tensorflow as tf
from tensorflow.keras import layers

class Dense(layers.Layer):

    def __init__(self, units):
        super(Dense, self).__init__()
        self.units = units

    def build(self, input_shape):
        # Kernel: (last input axis, units)
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        # Attempted 2D bias
        self.b = self.add_weight(
            name="b", initializer="random_normal", trainable=True, shape=(input_shape[-1], self.units),
        )

    def call(self, inputs):
        # x.w + b
        return tf.matmul(inputs, self.w) + self.b
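The shape arithmetic behind x.w + b with a 2D input can be checked in NumPy before writing any layer code. This is a minimal sketch, assuming a batch of single-channel images shaped (batch, rows, cols); the sizes are arbitrary. `@` contracts the last axis of the input with the first axis of the weight, so a 2D bias with one value per (row, output unit) broadcasts over the batch axis:

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes for illustration: 2 single-channel images of 5 rows x 10 columns
batch, rows, cols, units = 2, 5, 10, 4
x = rng.normal(size=(batch, rows, cols))

# 2D weight: contracts the column axis of every row, like a Dense kernel
w = rng.normal(size=(cols, units))
# 2D bias: one value per (row, output unit), broadcast over the batch axis
b = rng.normal(size=(rows, units))

y = x @ w + b  # (batch, rows, cols) @ (cols, units) -> (batch, rows, units)
print(y.shape)  # (2, 5, 4)

# Same result for the first image, computed entry by entry with explicit dot products
manual = np.array(
    [[x[0, r] @ w[:, u] + b[r, u] for u in range(units)] for r in range(rows)]
)
print(np.allclose(y[0], manual))  # True

# Transposing each image first contracts the row axis instead of the column axis
w_t = rng.normal(size=(rows, units))
y_t = np.transpose(x, (0, 2, 1)) @ w_t  # (batch, cols, rows) @ (rows, units)
print(y_t.shape)  # (2, 10, 4)
```

With the bias shape used in the question, (input_shape[-1], self.units), i.e. (cols, units), adding it to a (batch, rows, units) result only broadcasts when cols equals rows or one of them is 1.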

Answer:

If I understand correctly, it depends on your desired output. You can try something like this:

import tensorflow as tf

class Dense2D(tf.keras.layers.Layer):

    def __init__(self, units):
        super(Dense2D, self).__init__()
        self.units = units

    def build(self, input_shape):
        # Kernel: (cols, units), applied to every row of each 2D input
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        # 2D bias: (rows, units), matching the shape of matmul(inputs, w) without the batch axis
        self.b = self.add_weight(
            name="b", initializer="random_normal", trainable=True, shape=(input_shape[1], self.units),
        )

    def call(self, inputs):
        # (batch, rows, cols) x (cols, units) -> (batch, rows, units), plus the 2D bias
        return tf.matmul(inputs, self.w) + self.b

dense2d = Dense2D(units=10)
samples = 1
x = tf.random.normal((samples, 5, 10))
print(dense2d(x).shape)

# (1, 5, 10)
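If "multiply each pixel by a weight and add a bias" is meant element-wise (one weight and one bias per pixel, with no sum over columns), the weight and bias can simply take the input's non-batch shape. A minimal sketch, assuming fixed-size single-channel inputs shaped (batch, rows, cols); the class name `PixelwiseAffine` is hypothetical:

```python
import numpy as np
import tensorflow as tf


class PixelwiseAffine(tf.keras.layers.Layer):
    """Hypothetical layer computing y = x * w + b with one weight and bias per pixel."""

    def build(self, input_shape):
        # Every axis except the batch axis, e.g. (rows, cols); must be fully known
        param_shape = tuple(int(d) for d in input_shape[1:])
        self.w = self.add_weight(
            name="w", shape=param_shape, initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(
            name="b", shape=param_shape, initializer="zeros", trainable=True
        )

    def call(self, inputs):
        # Element-wise product and sum; w and b broadcast over the batch axis
        return inputs * self.w + self.b


layer = PixelwiseAffine()
x = tf.random.normal((3, 5, 10))
y = layer(x)
print(y.shape)  # (3, 5, 10)

# Check against a plain NumPy computation with the layer's current weights
w, b = layer.get_weights()
print(np.allclose(y.numpy(), x.numpy() * w + b))  # True
```

Unlike the matmul-based layer, this version never mixes pixels, so the output keeps the input's shape and the layer has rows * cols weights plus rows * cols biases.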