How to train a new custom layer with already trained layer using Subclass Modelling?

Time:12-10

The model MIMOSystem consists of many custom layers, of which only one, NeuralReceiver, is trainable (shown below). After training MIMOSystem, I save its weights (which belong to the NeuralReceiver layer only). Now I want to load these weights into a second model, MIMOSystem2, which also has a NeuralReceiver layer, and freeze that layer. Then I want to train MIMOSystem2 with the already-trained NeuralReceiver layer and a new trainable layer, NN_decoder. How do I load the weights into the NeuralReceiver() layer of MIMOSystem2 and freeze it?
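A minimal sketch of the save, load and freeze steps, under stated assumptions: `ToyReceiver` is a hypothetical stand-in for NeuralReceiver made of two Dense sublayers, and the weights are moved through a NumPy `.npz` file using the standard Keras `get_weights()`/`set_weights()` methods, so only this one layer's variables are written. Training the first receiver is omitted.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ToyReceiver(tf.keras.layers.Layer):
    """Hypothetical stand-in for NeuralReceiver: two Dense sublayers."""

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(8, activation="relu")
        self.dense2 = tf.keras.layers.Dense(2)

    def call(self, y):
        return self.dense2(self.dense1(y))


def save_layer_weights(layer, path):
    # get_weights() returns all variables of the layer and its sublayers as NumPy arrays
    np.savez(path, *layer.get_weights())


def load_layer_weights(layer, path):
    with np.load(path) as data:
        # np.savez names positional arrays arr_0, arr_1, ...; restore them in that order
        layer.set_weights([data[f"arr_{i}"] for i in range(len(data.files))])


dummy = tf.zeros((1, 4))

# Receiver of the first model (assume it has already been trained)
receiver1 = ToyReceiver()
receiver1(dummy)  # the first call creates the variables
path = os.path.join(tempfile.mkdtemp(), "receiver.npz")
save_layer_weights(receiver1, path)

# Fresh receiver of the second model: build it, load the weights, then freeze it
receiver2 = ToyReceiver()
receiver2(dummy)
load_layer_weights(receiver2, path)
receiver2.trainable = False

print(np.allclose(receiver1(dummy), receiver2(dummy)))  # True
```

In the question's setting, `receiver1` and `receiver2` would correspond to `model1.neural_receiver` and `model2.neural_receiver`; each model would need to run once (for example on a small batch) so the layer's variables exist before `set_weights` is called. This relies on both layers creating their variables in the same order, which holds when they are instances of the same class.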

class MIMOSystem(Model): # Inherits from Keras Model

    def __init__(self, training):

        super(MIMOSystem, self).__init__()
               
        self.training = training
        self.constellation = Constellation("qam", num_bits_per_symbol)
        self.mapper = Mapper(constellation=self.constellation)
        self.demapper = Demapper("app",constellation=self.constellation)
        self.binary_source = BinarySource()
        self.channel = ApplyFlatFadingChannel(add_awgn=True)
        self.neural_receiver = NeuralReceiver() # the only trainable layer
        self.encoder = encoder = LDPC5GEncoder(k, n) 
        self.decoder = LDPC5GDecoder(encoder, hard_out=True)
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        self.acc = tf.keras.metrics.BinaryAccuracy()
   
    @tf.function
    def __call__(self, batch_size, ebno_db):

        if self.training:
            coderate = 1.0
            codewords = self.binary_source([batch_size, num_tx_ant, k])
        else:
            coderate = k/n
            bits = self.binary_source([batch_size, num_tx_ant, k])
            codewords = self.encoder(bits)
        
        x = self.mapper(codewords)
        no = ebnodb2no(ebno_db,num_bits_per_symbol,coderate)
        channel_shape = [tf.shape(x)[0], num_rx_ant, num_tx_ant]
        h = complex_normal(channel_shape)        
        y = self.channel([x, h, no])

        x_hat, no_eff = self.neural_receiver(y,h) # custom trainable layer
    
        llr = self.demapper([x_hat, no_eff])
        
        if self.training:
            bits_hat = tf.nn.sigmoid(llr)
            loss = self.bce(codewords, bits_hat)
            acc = self.acc(codewords, bits_hat)
            return loss, acc
        else:
            bits_hat = self.decoder(llr)                       
            return bits, bits_hat 

The trainable layer NeuralReceiver() consists of a few sublayers; only two are shown here to give an idea.

class NeuralReceiver(Layer):
    def __init__(self):
        
        super().__init__()
        
        self.relu_layer = relu_layer()
        self.sign_layer = sign_layer() 
       
    def __call__(self, y_, H_):
        # (forward pass through the sublayers omitted; it computes x_hat and no_eff)
        return x_hat, no_eff
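Because NeuralReceiver owns its sublayers, setting `trainable = False` on the NeuralReceiver instance itself should be enough: Keras propagates the flag to every nested layer. The sketch below checks this with a hypothetical `ToyReceiver` built from two Dense sublayers (an assumption standing in for the real NeuralReceiver).

```python
import tensorflow as tf


class ToyReceiver(tf.keras.layers.Layer):
    """Hypothetical stand-in for NeuralReceiver with two sublayers."""

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(8, activation="relu")
        self.dense2 = tf.keras.layers.Dense(2)

    def call(self, y):
        return self.dense2(self.dense1(y))


receiver = ToyReceiver()
receiver(tf.zeros((1, 4)))  # build: each Dense creates a kernel and a bias

print(len(receiver.trainable_weights))  # 4

# Freezing the outer layer also freezes every sublayer
receiver.trainable = False
print(receiver.dense1.trainable, receiver.dense2.trainable)  # False False
print(len(receiver.trainable_weights), len(receiver.non_trainable_weights))  # 0 4
```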

MIMOSystem2 contains the frozen NeuralReceiver layer and a new trainable layer, NN_decoder:

class MIMOSystem2(Model): # Inherits from Keras Model

    def __init__(self, training):

        super(MIMOSystem2, self).__init__()
               
        self.training = training
        self.constellation = Constellation("qam", num_bits_per_symbol)
        self.mapper = Mapper(constellation=self.constellation)
        self.demapper = Demapper("app",constellation=self.constellation)
        self.binary_source = BinarySource()
        self.channel = ApplyFlatFadingChannel(add_awgn=True)
        self.neural_receiver = NeuralReceiver() # the frozen layer
        self.encoder = encoder = LDPC5GEncoder(k, n) 

        self.NN_decoder = NN_decoder() # new trainable layer
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        self.acc = tf.keras.metrics.BinaryAccuracy()
   
    @tf.function
    def __call__(self, batch_size, ebno_db):

        coderate = k/n
        bits = self.binary_source([batch_size, num_tx_ant, k])
        codewords = self.encoder(bits)
        
        x = self.mapper(codewords)
        no = ebnodb2no(ebno_db,num_bits_per_symbol,coderate)
        channel_shape = [tf.shape(x)[0], num_rx_ant, num_tx_ant]
        h = complex_normal(channel_shape)        
        y = self.channel([x, h, no])

        x_hat, no_eff = self.neural_receiver(y,h) # already-trained layer, to be frozen
    
        llr = self.demapper([x_hat, no_eff])

        bits_hat = self.NN_decoder(llr) # new trainable layer 
        
        if self.training:
            loss = self.bce(codewords, bits_hat)
            acc = self.acc(codewords, bits_hat)
            return loss, acc

        else:                     
            return bits, bits_hat 
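Assuming MIMOSystem2 is trained with a custom `tf.GradientTape` loop (the usual pattern for models like these, whose forward pass returns the loss), a frozen layer is left untouched simply because its variables drop out of `model.trainable_variables`. This sketch uses a hypothetical `ToySystem` whose Dense `receiver` and `decoder` stand in for NeuralReceiver and NN_decoder, with a mean-squared-error loss on random data in place of the real loss.

```python
import numpy as np
import tensorflow as tf


class ToySystem(tf.keras.Model):
    """Hypothetical stand-in for MIMOSystem2: a receiver followed by a decoder."""

    def __init__(self):
        super().__init__()
        self.receiver = tf.keras.layers.Dense(4)  # pretend this was trained earlier
        self.decoder = tf.keras.layers.Dense(4)   # new layer to be trained

    def call(self, y):
        return self.decoder(self.receiver(y))


model = ToySystem()
model(tf.zeros((1, 4)))           # build all variables
model.receiver.trainable = False  # freeze the pretrained part

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
loss_fn = tf.keras.losses.MeanSquaredError()

receiver_before = [w.copy() for w in model.receiver.get_weights()]
decoder_before = [w.copy() for w in model.decoder.get_weights()]

y = tf.random.normal((32, 4))
target = tf.random.normal((32, 4))

for _ in range(5):
    with tf.GradientTape() as tape:
        loss = loss_fn(target, model(y))
    # trainable_variables no longer contains the frozen receiver's kernel and bias
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

print(len(model.trainable_variables))  # 2 (decoder kernel and bias)
```

In MIMOSystem2 the forward pass already returns the loss (`loss, acc = model(batch_size, ebno_db)` with `training=True`), so the tape would wrap that call instead; the gradient step itself stays the same.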

Answer:

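import tensorflow as tf

# Freeze the pretrained layer. If it is the last layer, index it with -1; otherwise look it up by name with model1.get_layer("layer_name").
# Setting trainable = False on a layer also freezes all of its sublayers.
model1.layers[-1].trainable = False

# Now build the new model on top of it: call the frozen layer at the node where the new layers should attach (here the last one, hence -1), then stack the new layers.
inputs = tf.keras.Input(shape=...)  # input shape expected by the frozen layer
x = model1.layers[-1](inputs)
x = tf.keras.layers.Dense(...)(x)
...
...
...
model = tf.keras.Model(inputs, x)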
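A runnable version of this functional-API pattern, under stated assumptions: a single Dense layer named "pretrained_block" stands in for the already-trained layer (hypothetical; its training is omitted) and the new head is one Dense layer. The layer is frozen before `compile`, since Keras only picks up changes to `trainable` at compile time, and it is looked up with `get_layer` by name rather than by its position in `model.layers`. Note that the question's models override `__call__(batch_size, ebno_db)` and generate their own inputs, so this pattern likely does not carry over to them directly; there, freezing `model2.neural_receiver` by attribute is the more direct route.

```python
import numpy as np
import tensorflow as tf

# Hypothetical pretrained block (model1.layers[-1] in the answer); training omitted
pretrained = tf.keras.layers.Dense(8, activation="relu", name="pretrained_block")
pretrained(tf.zeros((1, 4)))  # build so the weights exist
pretrained.trainable = False  # freeze BEFORE compiling the new model

inputs = tf.keras.Input(shape=(4,))
x = pretrained(inputs)
outputs = tf.keras.layers.Dense(2, name="new_head")(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# Only the new head should change during training
before = [w.copy() for w in model.get_layer("pretrained_block").get_weights()]
x_train = np.random.rand(16, 4).astype("float32")
y_train = np.random.rand(16, 2).astype("float32")
model.fit(x_train, y_train, epochs=2, verbose=0)
after = model.get_layer("pretrained_block").get_weights()
```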
