The model MIMOSystem consist of many custom layers out of which only one layer is trainable, that is NeuralReceiver as shown below. After training the MIMOSystem model, I save the weights of it (weights belong to NeuralReceiver layer only). Now I want to load the weights onto second model MIMOSystem2 since it also has NeuralReceiver layer and freeze it. And train the MIMOSystem2 with already trained NeuralReceiver layer and new trainable layer NN_decoder. How do I load weights to NeuralReceiver() layer of MIMOSystem2 model and freeze it?
class MIMOSystem(Model): # Inherits from Keras Model
def __init__(self, training):
super(MIMOSystem, self).__init__()
self.training = training
self.constellation = Constellation("qam", num_bits_per_symbol)
self.mapper = Mapper(constellation=self.constellation)
self.demapper = Demapper("app",constellation=self.constellation)
self.binary_source = BinarySource()
self.channel = ApplyFlatFadingChannel(add_awgn=True)
self.neural_receiver = NeuralReceiver() # the only trainable layer
self.encoder = encoder = LDPC5GEncoder(k, n)
self.decoder = LDPC5GDecoder(encoder, hard_out=True)
self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
self.acc = tf.keras.metrics.BinaryAccuracy()
@tf.function
def __call__(self, batch_size, ebno_db):
if self.training:
coderate = 1.0
codewords = self.binary_source([batch_size, num_tx_ant, k])
else:
coderate = k/n
bits = self.binary_source([batch_size, num_tx_ant, k])
codewords = self.encoder(bits)
x = self.mapper(codewords)
no = ebnodb2no(ebno_db,num_bits_per_symbol,coderate)
channel_shape = [tf.shape(x)[0], num_rx_ant, num_tx_ant]
h = complex_normal(channel_shape)
y = self.channel([x, h, no])
x_hat, no_eff = self.neural_receiver(y,h) # custom trainable layer
llr = self.demapper([x_hat, no_eff])
if self.training:
bits_hat = tf.nn.sigmoid(llr)
loss = self.bce(codewords, bits_hat)
acc = self.acc(codewords, bits_hat)
return loss, acc
else:
bits_hat = self.decoder(llr)
return bits, bits_hat
The trainable layer NeuralReceiver() consist of few sublayers, only two mentioned to give an idea.
class NeuralReceiver(Layer):
def __init__(self):
super().__init__()
self.relu_layer = relu_layer()
self.sign_layer = sign_layer()
def __call__(self, y_, H_):
return x_hat, no_eff
MIMOSystem2 would have a freezed layer NeuralReceiver and trainable layer NN_decoder
class MIMOSystem2(Model): # Inherits from Keras Model
def __init__(self, training):
super(MIMOSystem2, self).__init__()
self.training = training
self.constellation = Constellation("qam", num_bits_per_symbol)
self.mapper = Mapper(constellation=self.constellation)
self.demapper = Demapper("app",constellation=self.constellation)
self.binary_source = BinarySource()
self.channel = ApplyFlatFadingChannel(add_awgn=True)
self.neural_receiver = NeuralReceiver() # the frozen layer
self.encoder = encoder = LDPC5GEncoder(k, n)
self.NN_decoder = NN_decoder() # new trainable layer
self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
self.acc = tf.keras.metrics.BinaryAccuracy()
@tf.function
def __call__(self, batch_size, ebno_db):
coderate = k/n
bits = self.binary_source([batch_size, num_tx_ant, k])
codewords = self.encoder(bits)
x = self.mapper(codewords)
no = ebnodb2no(ebno_db,num_bits_per_symbol,coderate)
channel_shape = [tf.shape(x)[0], num_rx_ant, num_tx_ant]
h = complex_normal(channel_shape)
y = self.channel([x, h, no])
x_hat, no_eff = self.neural_receiver(y,h) # already trainable layer to be frozen
llr = self.demapper([x_hat, no_eff])
bits_hat = self.NN_decoder(llr) # new trainable layer
if self.training:
loss = self.bce(codewords, bits_hat)
acc = self.acc(codewords, bits_hat)
return loss, acc
else:
return bits, bits_hat
CodePudding user response:
#If it's the last layer then simply put the "-1" but if you don't know then write the name of the layer then
for layer in model1.layers[-1].submodules:
layer.trainable = False
#Now append your model, after which node you wanna append your node mention that, I am appending after the last node, So I wrote -1.
x= model1.layers[-1](_input)
x = tf.keras.layers.Dense(...)(x)
...
...
...
model = tf.keras.Model(inputs, x)
CodePudding user response:
for layer in model1.layers[-1].submodules: layer.trainable = False
#Now append your model, after which node you wanna append your node mention that, I am appending after the last node, So I wrote -1. x= model1.layers-1 x = tf.keras.layers.Dense(...)(x) ... ... ... model = tf.keras.Model(inputs, x)