The function of sum() in python

Time:01-06

I am reading some Python code that deals with probability distributions. In many places, rather than relying on a framework such as PyTorch to compute the distribution, it uses code like the following:

def sample_energy_0(self, y, M):
    device = next(self.parameters()).device
    x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
    return x

def energy_0(self, x, y):
    return (x**2).sum(axis=2, keepdims=True) / 2

How does the function sum(axis=2) work here?

CodePudding user response:

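x.sum(axis=2) means that, for a tensor of shape (A, B, C, ...), the elements along the third dimension (the one of size C) are summed, yielding a new tensor of shape (A, B, ...). keepdims=True means that this dimension is not removed but kept as a dimension of size 1, so the result has shape (A, B, 1, ...). In your energy_0, if x has the shape produced by sample_energy_0, that is (M, y.shape[0], latent_dim), the result has shape (M, y.shape[0], 1): half the sum of squares of each latent vector.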

>>> import torch
>>> x = torch.arange(0, 30).reshape(2,3,5)
>>> x
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14]],

        [[15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]]])
>>> x.shape
torch.Size([2, 3, 5])
>>> x.sum(axis=2).shape
torch.Size([2, 3])
>>> x.sum(axis=2, keepdims=True).shape
torch.Size([2, 3, 1])
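
To make the arithmetic visible, the same reduction can be written out with plain loops. This is only a sketch in pure Python: the helper name sum_last_axis is made up for illustration, nested lists stand in for the tensor, and it says nothing about how PyTorch implements the operation internally.

```python
def sum_last_axis(x, keepdims=False):
    """Sum a 3-level nested list over its innermost level (axis 2).

    Hypothetical helper for illustration; not part of PyTorch or NumPy.
    """
    out = []
    for block in x:                # walk axis 0
        rows = []
        for row in block:          # walk axis 1
            total = sum(row)       # collapse axis 2 into one number
            # keepdims=True wraps the number in a length-1 list,
            # which is the "dimension of size 1" in the shape.
            rows.append([total] if keepdims else total)
        out.append(rows)
    return out


# Same values as torch.arange(0, 30).reshape(2, 3, 5), as nested lists.
x = [[[5 * (3 * a + b) + c for c in range(5)] for b in range(3)] for a in range(2)]

reduced = sum_last_axis(x)
kept = sum_last_axis(x, keepdims=True)

print(reduced)  # [[10, 35, 60], [85, 110, 135]]
print(kept)     # [[[10], [35], [60]], [[85], [110], [135]]]
```

The values are identical in both results; only the nesting differs, which is what the shapes (2, 3) and (2, 3, 1) express.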

CodePudding user response:

The expression (x**2).sum(axis=2, keepdims=True) / 2 isn't calling the function sum(), it's calling a method named .sum() on some object.
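
The difference between the built-in function and a method can be seen in a small sketch. The class Samples and its scale keyword are invented for this example; the only point is that obj.sum(...) looks up an attribute on obj, so it can accept arguments the built-in sum() knows nothing about.

```python
class Samples:
    """Hypothetical container, used only to illustrate method lookup."""

    def __init__(self, values):
        self.values = list(values)

    def __iter__(self):
        # Makes the object usable with the built-in sum().
        return iter(self.values)

    def sum(self, scale=1):
        # A method defined by the class; 'scale' is an invented keyword.
        return scale * sum(self.values)


s = Samples([1, 2, 3])

builtin_total = sum(s)           # built-in function, iterates over s
method_total = s.sum(scale=10)   # method looked up on the object s

print(builtin_total)  # 6
print(method_total)   # 60
```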

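The type of the object cannot be determined from energy_0 alone, but x is most likely expected to be a Tensor, which has a .sum() method that you can find in the PyTorch documentation. However, since the keyword argument is spelled axis, which is the NumPy spelling, it is also possible that the function was written with NumPy arrays in mind, in which case the call resolves to the array's own .sum() method (numpy.ndarray.sum). PyTorch accepts axis and keepdims as well, so the keyword alone does not settle the question.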

The (x**2) part applies the ** (power) operator to x and 2, and the result typically has the same type as x. So the .sum() method can be called on the result as well.

Note that you could call energy_0() with x being either a Tensor or a NumPy array and it would work in both cases, because the function only relies on operations that both classes support: the ** operator, a .sum() method that accepts axis and keepdims, and division by a number.
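
The same point can be shown without either library. In the sketch below, Nested3D is a hypothetical stand-in class that implements just those three operations on a nested list, and energy_0 is rewritten as a plain function (dropping self and the unused y, which is an assumption made only to keep the example self-contained).

```python
def energy_0(x):
    # Same expression as in the question; works for any object that
    # supports **, .sum(axis=..., keepdims=...) and /.
    return (x ** 2).sum(axis=2, keepdims=True) / 2


class Nested3D:
    """Hypothetical tensor stand-in backed by a 3-level nested list."""

    def __init__(self, data):
        self.data = data

    def _map(self, fn):
        # Apply fn to every innermost element, keeping the nesting.
        return Nested3D([[[fn(v) for v in row] for row in block]
                         for block in self.data])

    def __pow__(self, p):
        return self._map(lambda v: v ** p)

    def __truediv__(self, d):
        return self._map(lambda v: v / d)

    def sum(self, axis, keepdims=False):
        # This sketch only covers the case used by energy_0.
        if axis != 2 or not keepdims:
            raise NotImplementedError("only axis=2, keepdims=True is sketched")
        return Nested3D([[[sum(row)] for row in block] for block in self.data])


x = Nested3D([[[1.0, 2.0], [3.0, 4.0]]])   # "shape" (1, 2, 2)
result = energy_0(x).data

print(result)  # [[[2.5], [12.5]]]
```

Each inner list [a, b] becomes [(a**2 + b**2) / 2], so energy_0 never needs to know which class it was given.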
