I'm trying to convert some PyTorch code to TensorFlow. The PyTorch code adds an extra parameter to every module of a model using module.register_parameter(name, param). How can I convert this part of the code to TensorFlow?
Sample code below:
for module_name, module in self.model.named_modules():
    module.register_parameter(name, new_parameter)
tf.Variable is the closest equivalent of nn.Parameter in PyTorch. tf.Variable is mainly used to store model parameters, because their values are updated at every training step.
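A minimal sketch of that update cycle, assuming TensorFlow 2.x with eager execution. The toy loss (w - 5)^2 and the learning rate 0.1 are illustrative choices, and the manual assign_sub stands in for what an optimizer would do.

```python
import tensorflow as tf

# A trainable tf.Variable plays the role of nn.Parameter: mutable state that
# gradients are computed against and that is updated in place during training.
w = tf.Variable(2.0, name="w")  # trainable=True is the default

learning_rate = 0.1  # illustrative value

# One gradient-descent step on the toy loss (w - 5)^2.
with tf.GradientTape() as tape:
    loss = (w - 5.0) ** 2
grad = tape.gradient(loss, w)  # d(loss)/dw = 2 * (w - 5) = -6.0 at w = 2.0
w.assign_sub(learning_rate * grad)  # in-place update: w becomes about 2.6

print(float(w.numpy()))
```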
To use a tensor as a new model parameter, you need to convert it to a tf.Variable. The TensorFlow documentation on variables shows how to create a variable from a tensor.
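For instance, wrapping an existing tensor (a random one here, purely for illustration) in tf.Variable produces a trainable parameter, roughly what nn.Parameter(tensor) does in PyTorch:

```python
import tensorflow as tf

# A plain tensor is immutable and is not tracked as a model parameter.
init = tf.random.normal([3, 4])

# tf.Variable copies the tensor's value into mutable, trainable state.
new_parameter = tf.Variable(init, trainable=True, name="new_parameter")

print(new_parameter.shape, new_parameter.trainable)

# Unlike the original tensor, the variable can be updated in place.
new_parameter.assign(tf.zeros([3, 4]))
```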
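If you want to add a model parameter inside the model itself, simply create the variable inside the model (or layer) class, for example with self.add_weight(...) in a Keras layer, or by assigning a tf.Variable to an attribute of a tf.Module. TensorFlow tracks it automatically and includes it in the model's trainable weights.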
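A sketch of the in-class approach, assuming a Keras layer. ScaledDense and its scale weight are hypothetical names. The weight is created with add_weight, which both Keras 2 and Keras 3 track, so it appears in trainable_weights next to the inner Dense layer's kernel and bias.

```python
import tensorflow as tf

class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical layer: a Dense layer followed by a learnable scalar scale."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)
        # Created inside the class, so Keras registers it as a trainable weight.
        self.scale = self.add_weight(
            name="scale", shape=(), initializer="ones", trainable=True
        )

    def call(self, inputs):
        return self.scale * self.dense(inputs)

layer = ScaledDense(2)
out = layer(tf.ones([1, 3]))  # calling the layer builds the inner Dense layer

# Expect three trainable weights: scale, plus the Dense kernel and bias.
print([v.name for v in layer.trainable_weights])
```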
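If you want to attach a tf.Variable to an existing model from the outside, extending the trainable_weights attribute of a tf.keras.layers.Layer (e.g. model.layers[-1].trainable_weights.extend([new_parameter])) will not work: trainable_weights is a read-only property that builds a new list on every access, so the extension is discarded. Instead, assign the variable as an attribute of the layer, which tf.keras (Keras 2) tracks automatically:
setattr(model.layers[-1], "new_parameter", new_parameter)
As with register_parameter in PyTorch, the new variable only receives gradients if the layer's computation actually uses it.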
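For comparison with the named_modules() loop in the question, here is a minimal sketch using plain tf.Module, the core TensorFlow base class whose submodules property and attribute tracking resemble torch.nn.Module. Net, Block, and the parameter name extra_scale are hypothetical; a Keras model would need the loop adapted to its own layers.

```python
import tensorflow as tf

class Block(tf.Module):
    """Hypothetical submodule holding one weight."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.ones([2, 2]), name="w")

class Net(tf.Module):
    """Hypothetical model made of two submodules."""

    def __init__(self):
        super().__init__(name="net")
        self.block1 = Block(name="block1")
        self.block2 = Block(name="block2")

model = Net()
param_name = "extra_scale"  # hypothetical parameter name

# tf.Module.submodules excludes the module itself, so prepend `model` to
# mirror PyTorch's named_modules(), which also yields the root module.
for module in (model,) + tuple(model.submodules):
    # Assigning a tf.Variable as an attribute is enough for tf.Module to
    # track it, much like register_parameter in PyTorch.
    setattr(module, param_name, tf.Variable(1.0, name=param_name))

# Two original weights plus one new parameter on each of the three modules.
print(len(model.trainable_variables))
```

Because tf.Module discovers variables by walking object attributes, anything assigned with setattr is picked up by trainable_variables without a separate registration call.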