I am reading some Python code that deals with probability distributions. In many places, instead of using a framework like PyTorch to compute the distribution, it uses code like the following:
def sample_energy_0(self, y, M):
    device = next(self.parameters()).device
    x = torch.randn(M, y.shape[0], self.latent_dim).to(device)
    return x

def energy_0(self, x, y):
    return (x**2).sum(axis=2, keepdims=True) / 2
How does the function sum(axis=2) work here?
CodePudding user response:
x.sum(axis=2) means that, for a tensor of shape (A, B, C, ...), the elements along the C dimension are summed to yield a new tensor of shape (A, B, ...). keepdims=True means that the C dimension is "kept" as a dummy dimension of size 1:
>>> import torch
>>> x = torch.arange(0, 30).reshape(2,3,5)
>>> x
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14]],

        [[15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]]])
>>> x.shape
torch.Size([2, 3, 5])
>>> x.sum(axis=2).shape
torch.Size([2, 3])
>>> x.sum(axis=2, keepdims=True).shape
torch.Size([2, 3, 1])
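To connect this back to the question's energy_0, here is a minimal sketch. It assumes a standalone copy of the function (self and y are dropped for illustration) and hypothetical sizes M, B, D. It shows that the result has shape (M, B, 1) and is half the squared Euclidean norm of each length-D vector along the last dimension.

```python
import torch

def energy_0(x):
    # Standalone copy of the method from the question; self and y are
    # dropped here purely for illustration.
    return (x**2).sum(axis=2, keepdims=True) / 2

# Hypothetical sizes: M samples, batch of B, latent dimension D.
M, B, D = 4, 3, 5
x = torch.randn(M, B, D, dtype=torch.float64)

e = energy_0(x)
print(e.shape)  # torch.Size([4, 3, 1])

# Same quantity written as half the squared norm along the last dimension.
ref = 0.5 * x.norm(dim=2, keepdim=True) ** 2
print(torch.allclose(e, ref))  # True
```

Because the summed dimension is kept with size 1, the result can still be broadcast against tensors of shape (M, B, D).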
CodePudding user response:
The expression (x**2).sum(axis=2, keepdims=True) / 2 isn't calling the function sum(); it's calling a method named .sum() on some object.
The type of the object can't be determined from this snippet alone, but x is most likely expected to be a Tensor, which has a .sum() method that you can find in the PyTorch documentation. The keyword argument axis is the NumPy spelling, so the function might instead expect a numpy array and call numpy.ndarray.sum(). Note, however, that Tensor.sum() also accepts axis and keepdims as aliases for dim and keepdim, so the keyword alone doesn't settle it.
The (x**2) applies the ** (raising to a power) operator to x and 2, and the result is typically of the same type as x. So the .sum() method can be called on the result.
Note that you could call energy_0() with x being either a Tensor or a numpy array and it would work just fine, as long as the code in the function is compatible with the interface of the class of the object being passed.
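As a rough check of that last point, the sketch below calls a standalone copy of energy_0 (again with self and y dropped, which is an assumption made for illustration) on both a numpy array and a Tensor built from the same data. It assumes a PyTorch version in which Tensor.sum() accepts the axis and keepdims keywords, as in the session shown in the other answer.

```python
import numpy as np
import torch

def energy_0(x):
    # Works for any object that supports ** and .sum(axis=..., keepdims=...).
    return (x**2).sum(axis=2, keepdims=True) / 2

x_np = np.arange(30, dtype=np.float64).reshape(2, 3, 5)
x_t = torch.from_numpy(x_np)  # Tensor holding the same values

e_np = energy_0(x_np)
e_t = energy_0(x_t)

# The result type follows the input type in each case.
print(type(e_np).__name__, e_np.shape)       # ndarray (2, 3, 1)
print(type(e_t).__name__, tuple(e_t.shape))  # Tensor (2, 3, 1)

# Both paths compute the same numbers.
print(np.allclose(e_np, e_t.numpy()))  # True
```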