Python Unit Testing: How is function automatically called without providing function name?

Time:05-30

I am looking into the code in vanilla_vae and its unit test, test_vae.

In the test_vae snippet below, I am confused about how self.model(x) in the test_loss(self) method ends up calling the VanillaVAE class's forward method even though forward is never named. Could anyone provide some insight into this?

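```python
import unittest

import torch

from models import VanillaVAE  # assumed import path; adjust to where VanillaVAE lives


class TestVAE(unittest.TestCase):  # enclosing class assumed; only the two methods were quoted

    def setUp(self) -> None:
        # self.model2 = VAE(3, 10)
        self.model = VanillaVAE(3, 10)

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)  # batch of 16 three-channel 64x64 inputs

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N=0.005)
        print(loss)
```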

Answer:

This is because VanillaVAE (defined in vanilla_vae) inherits from BaseVAE, which in turn inherits from nn.Module.

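nn.Module defines a __call__ method. __call__ is one of Python's special ("dunder") methods: when a class defines it, instances of that class can be called like functions, so self.model(x) is equivalent to self.model.__call__(x).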
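A minimal plain-Python illustration of that mechanism (no PyTorch involved; the class names Scaler and Plain are hypothetical): defining __call__ on a class makes its instances callable, and Python resolves obj(x) by looking up __call__ on the object's class.

```python
class Scaler:
    """Hypothetical class whose instances can be called like functions."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        # Runs whenever an instance is called, e.g. s(2.0)
        return x * self.factor


class Plain:
    """No __call__ defined, so instances are not callable."""


s = Scaler(3.0)
print(s(2.0))                          # 6.0
print(s.__call__(2.0))                 # 6.0 (explicit form of the same call)
print(type(s).__call__(s, 2.0))        # 6.0 (how Python resolves s(2.0): via the class)
print(callable(s), callable(Plain()))  # True False
```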

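nn.Module's __call__ dispatches to an internal method, _call_impl, which runs any registered hooks and calls self.forward(...). Because VanillaVAE overrides forward, its own forward is the one that runs. (The exact internal names may vary between PyTorch versions.)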
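The dispatch can be sketched with a toy base class. This is a heavily simplified, hypothetical stand-in for nn.Module (MiniModule, MiniBaseVAE and MiniVanillaVAE are made-up names), not PyTorch's actual code; the real _call_impl also deals with pre-hooks, backward hooks, tracing and more. What it shows is that forward is found through ordinary inheritance and method lookup, so the subclass's override is what runs:

```python
from typing import Any, Callable, List


class MiniModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[..., Any]] = []

    def register_forward_hook(self, hook: Callable[..., Any]) -> None:
        # Hooks receive (module, inputs, output), loosely mirroring PyTorch's convention.
        self._forward_hooks.append(hook)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError("subclasses must override forward")

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        # self.forward is looked up on the instance, so a subclass override wins.
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook_result = hook(self, args, output)
            if hook_result is not None:  # a hook may replace the output
                output = hook_result
        return output

    # Calling an instance runs __call__, which here is simply _call_impl.
    __call__ = _call_impl


class MiniBaseVAE(MiniModule):
    """Hypothetical intermediate base class, playing the role of BaseVAE."""


class MiniVanillaVAE(MiniBaseVAE):
    """Hypothetical subclass: only forward is defined here."""

    def forward(self, x: float) -> List[float]:
        return [x * 2, x]


model = MiniVanillaVAE()
seen = []
model.register_forward_hook(lambda mod, inputs, out: seen.append(out))
print(model(5))          # [10, 5]: __call__ -> _call_impl -> forward
print(seen)              # [[10, 5]]: the hook ran because the call went through __call__
print(model.forward(5))  # [10, 5], but calling forward directly skips the hooks
print(seen)              # still [[10, 5]]
```

The last two lines also show why the usual advice is to call model(x) rather than model.forward(x): only the former goes through __call__, so registered hooks are silently skipped when forward is called directly.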

Answer:

This behavior comes from torch.nn.Module, the PyTorch base class for neural network modules. In the forward method, you define how your model computes its output from its input.

This means that every time you call your model on an input, forward is called automatically and its return value is handed back to you. In this case, judging from the code you linked, it returns a List[Tensor]:

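```python
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    mu, log_var = self.encode(input)      # encoder gives the latent mean and log-variance
    z = self.reparameterize(mu, log_var)  # reparameterization trick: sample z from N(mu, exp(log_var))
    return [self.decode(z), input, mu, log_var]  # reconstruction, original input, mu, log_var
```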
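To connect this to the test in the question: because forward returns a four-item list, result = self.model(x) followed by self.model.loss_function(*result, M_N = 0.005) unpacks those four items as positional arguments. The sketch below is a hypothetical, tensor-free stand-in (ToyVAE and its toy encoder, decoder and loss are assumptions, not the repository's code) that keeps the same [reconstruction, input, mu, log_var] layout. The KL term uses the standard closed form for a diagonal Gaussian against N(0, 1).

```python
import math
import random
from typing import Dict, List, Tuple


class ToyVAE:
    """Hypothetical, tensor-free stand-in for VanillaVAE (plain lists of floats)."""

    def __call__(self, x: List[float]) -> list:
        # Like nn.Module, calling the instance runs forward.
        return self.forward(x)

    def encode(self, x: List[float]) -> Tuple[List[float], List[float]]:
        # Toy encoder: mean is half the input, log-variance is fixed at 0 (std = 1).
        return [0.5 * v for v in x], [0.0 for _ in x]

    def reparameterize(self, mu: List[float], log_var: List[float]) -> List[float]:
        # z = mu + std * eps, with std = exp(0.5 * log_var) and eps ~ N(0, 1)
        return [m + math.exp(0.5 * lv) * random.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

    def decode(self, z: List[float]) -> List[float]:
        # Toy decoder: undo the encoder's halving.
        return [2.0 * v for v in z]

    def forward(self, x: List[float]) -> list:
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        # Same layout as the forward above: [reconstruction, input, mu, log_var]
        return [self.decode(z), x, mu, log_var]

    def loss_function(self, *args, **kwargs) -> Dict[str, float]:
        # The four list items arrive positionally because the caller wrote *result.
        recons, x, mu, log_var = args[0], args[1], args[2], args[3]
        kld_weight = kwargs["M_N"]  # keyword argument, as in the question's test
        recons_loss = sum((r - v) ** 2 for r, v in zip(recons, x)) / len(x)
        # KL divergence between N(mu, exp(log_var)) and N(0, 1), summed over dimensions
        kld = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))
        return {"loss": recons_loss + kld_weight * kld, "recons": recons_loss, "kld": kld}


random.seed(0)
model = ToyVAE()
x = [1.0, -2.0, 0.5]
result = model(x)  # forward runs; result is a 4-item list
loss = model.loss_function(*result, M_N=0.005)
print(len(result), result[1] is x)  # 4 True
print(round(loss["kld"], 5))        # 0.65625
```

Returning a flat list keeps the call site short; the trade-off is that the loss function has to rely on positional order, so forward and loss_function must agree on that order.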

Here you can also find a couple of examples of how PyTorch's nn package is used.
