What is the alternative of pytorch module.register_parameter(name, param) in tensorflow?

Time:05-18

I'm trying to convert some PyTorch code to TensorFlow. The PyTorch code adds an extra parameter to every module of a model using module.register_parameter(name, param). How can I convert this part of the code to TensorFlow?

Sample code:

# name (a str) and new_parameter (an nn.Parameter) are defined elsewhere in the original code
for module_name, module in self.model.named_modules():
    module.register_parameter(name, new_parameter)


Answer:

tf.Variable is the TensorFlow equivalent of PyTorch's nn.Parameter. It is mainly used to store model parameters because, unlike an ordinary tensor, its value can be updated in place during training.
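A minimal sketch of that correspondence: a trainable tf.Variable is watched by tf.GradientTape automatically, much as an nn.Parameter has requires_grad=True by default. The variable name w and the toy loss are illustrative only.

```python
import tensorflow as tf

# A trainable variable plays the role of nn.Parameter (trainable=True is the default).
w = tf.Variable(3.0, name="w")

with tf.GradientTape() as tape:
    # Trainable variables are watched automatically; no tape.watch() call is needed.
    loss = w * w

grad = tape.gradient(loss, w)  # d(w^2)/dw = 2w
print(w.trainable, float(grad))  # prints: True 6.0
```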

To use a tensor as a new model parameter, you need to convert it to a tf.Variable by passing it to the tf.Variable constructor. See the TensorFlow guide on variables for details on creating variables from tensors.
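A short sketch of wrapping an existing tensor, assuming eager execution (the TF2 default). tf.Variable copies the tensor's value and dtype into new mutable storage; the names t and v are illustrative.

```python
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # an ordinary, immutable tensor

# Wrapping the tensor creates a mutable, trainable variable initialised from its value.
v = tf.Variable(t, trainable=True, name="new_parameter")

# Unlike a tensor, a variable can be updated in place (as an optimizer would do).
v.assign_add(tf.ones_like(v))

# The original tensor is unchanged; only the variable was updated.
print(t.numpy()[0, 0], v.numpy()[0, 0])  # prints: 1.0 2.0
```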

If you want to add a model parameter inside the model itself, simply create a tf.Variable and assign it as an attribute in the model class (for example in __init__). TensorFlow tracks it automatically and reports it as a model parameter, e.g. in trainable_variables.
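A minimal sketch using tf.Module, the base class that tf.keras layers and models also build on in TF 2.x with Keras 2. Any tf.Variable assigned as an attribute is collected by trainable_variables without an explicit registration call. The class Scale and its attribute names are hypothetical.

```python
import tensorflow as tf

class Scale(tf.Module):
    """Hypothetical module holding one weight and one extra parameter."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.weight = tf.Variable(2.0, name="weight")
        # Assigning a variable to an attribute is enough for tf.Module to track it;
        # there is no register_parameter() step as in PyTorch.
        self.extra = tf.Variable(0.5, name="extra")

    def __call__(self, x):
        return self.weight * x + self.extra

model = Scale()
# Both attribute variables are discovered automatically.
print([v.name for v in model.trainable_variables])
```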

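If you want to attach an existing tf.Variable to a model from the outside, assign it as an attribute of one of its layers (or of the model itself), for example:

setattr(model.layers[-1], "new_parameter", new_parameter)

tf.Module objects, and tf.keras layers in Keras 2, track variables assigned to attributes, so the variable then shows up in trainable_weights. Extending the list returned by trainable_weights, as in model.layers[-1].trainable_weights.extend([new_parameter]), has no lasting effect, because trainable_weights is a property that builds a new list each time it is accessed.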
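A rough TensorFlow counterpart of the PyTorch loop from the question might look like the sketch below. It assumes the model is a tf.Module (tf.keras layers and models are tf.Module subclasses in Keras 2). tf.Module.submodules plays the role of named_modules(), except that it does not include the module itself, so the root is added explicitly. Block, Net and the attribute name are hypothetical; as in the question, one shared variable is attached to every module.

```python
import tensorflow as tf

class Block(tf.Module):
    """Hypothetical sub-module with one ordinary weight."""
    def __init__(self, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(1.0, name="w")

class Net(tf.Module):
    """Hypothetical model containing two sub-modules."""
    def __init__(self, name=None):
        super().__init__(name=name)
        self.block1 = Block()
        self.block2 = Block()

model = Net()
name = "new_parameter"  # hypothetical attribute name
new_parameter = tf.Variable(0.0, name=name)

# named_modules() in PyTorch yields the model itself plus every sub-module;
# tf.Module.submodules excludes the model itself, so it is prepended here.
for module in (model, *model.submodules):
    # Setting the attribute is what makes tf.Module track the variable.
    setattr(module, name, new_parameter)

# The shared variable is now tracked alongside the two Block weights.
print([v.name for v in model.trainable_variables])
```

If each module should get its own parameter instead, create a new tf.Variable inside the loop rather than reusing one object.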