PyTorch derivative calculation

Time:01-27

I have this simple PyTorch code:

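import torch

x = torch.arange(3, dtype=float)   # tensor([0., 1., 2.], dtype=torch.float64)
x.requires_grad_(True)
y = 3*x + x.sum()                  # y_i = 3*x_i + (x_0 + x_1 + x_2)
y.backward(torch.ones(3))
print(x.grad)                      # tensor([6., 6., 6.], dtype=torch.float64)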

This gives me [6,6,6], but shouldn't it be [4,4,4]? Because if we have f(x) = 3 * x0 + 3 * x1 + 3 * x2 + x0 + x1 + x2, the partial derivatives would be 3 + 1 = 4?

Answer:

The result is correct, and here is why.

I will refer to the first element of your result; the same reasoning extends to the other elements. You expected x.grad[0] to be dy1/dx1, but that is not what your code computes. It computes dy1/dx1 + dy2/dx1 + dy3/dx1 = (3 + 1) + 1 + 1 = 6, because x1 appears in every yi through x.sum().
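To see where the 6 comes from, one can differentiate each output element separately and look at how much each one contributes to the derivative with respect to the first input. This is a minimal sketch using torch.autograd.grad; it uses 0-based indices, so y[0] and x[0] correspond to y1 and x1 above, and the names rows and contributions are illustrative, not from the original post.

```python
import torch

x = torch.arange(3, dtype=torch.float64, requires_grad=True)
y = 3 * x + x.sum()  # y_i = 3*x_i + (x_0 + x_1 + x_2)

# Illustrative: gradient of each scalar output y[i] with respect to the whole vector x.
# retain_graph=True keeps the graph alive so we can differentiate it several times.
rows = [torch.autograd.grad(y[i], x, retain_graph=True)[0] for i in range(3)]

# Illustrative: contribution of each y[i] to the derivative with respect to x[0]
contributions = [row[0].item() for row in rows]
print(contributions)       # [4.0, 1.0, 1.0]
print(sum(contributions))  # 6.0, the value that ends up in x.grad[0]
```

Only the first contribution (4.0) is the "diagonal" derivative the question expected; the other two come from x.sum().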

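Passing a tensor of ones to .backward() means the result is the vector-Jacobian product ones · (dy/dx), where dy/dx is the 3x3 Jacobian matrix whose entry (i, j) is dy_i/dx_j. Multiplying by a row of ones sums each column of that matrix, which is why every entry of x.grad is 6 rather than the diagonal entry 4.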
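The vector-Jacobian product can be checked explicitly by building the full Jacobian. This sketch assumes a PyTorch version that provides torch.autograd.functional.jacobian; the helper f and the variable names J, v, vjp and x2 are illustrative, not part of the original question. It also shows that the [4,4,4] the question expected is the diagonal of the Jacobian.

```python
import torch
from torch.autograd.functional import jacobian

# Illustrative helper (not from the original post): same function as in the question
def f(x):
    return 3 * x + x.sum()  # y_i = 3*x_i + sum(x)

x = torch.arange(3, dtype=torch.float64)

J = jacobian(f, x)  # 3x3 matrix with J[i, j] = dy_i/dx_j
print(J)            # 4 on the diagonal (3 + 1), 1 everywhere else

v = torch.ones(3, dtype=torch.float64)
vjp = v @ J         # ones^T J: sums each column of J
print(vjp)          # tensor([6., 6., 6.], dtype=torch.float64)

# Compare with what backward(ones) accumulates into .grad
x2 = x.clone().requires_grad_(True)
f(x2).backward(v)
print(torch.equal(x2.grad, vjp))  # True

# The per-element derivatives the question expected are the diagonal of J
print(J.diagonal())  # tensor([4., 4., 4.], dtype=torch.float64)
```

If the diagonal is what you actually want, it has to be extracted from the Jacobian (or computed another way); backward with a vector of ones will always give the column sums.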
