I have 2 tensors with gradients:
a = tensor([[0.0000, 0.0000, 0.2716, 0.0000, 0.4049, 0.0000, 0.2126, 0.8649, 0.0000,
0.0000]], grad_fn=<ReluBackward0>)
b = tensor([[0.5842, 0.4618, 0.4047, 0.5714, 0.4841, 0.5683, 0.4030, 0.3779, 0.4436,
0.4365]], grad_fn=<SigmoidBackward>)
I'm trying to use the second tensor (b) as a threshold while maintaining the differentiability of the tensors:
torch.where(a < b, 0, a)
However, I'm getting an error:
RuntimeError: expected scalar type long long but found float
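My guess is that the integer literal 0 is promoted to a long scalar, which then clashes with the float dtype of a. If that's the cause, passing a dtype-matched zero tensor should at least avoid the mismatch:
torch.where(a < b, torch.zeros_like(a), a)  # zeros with a's dtype and shape
though I wasn't sure whether this keeps the gradients I need.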
I can convert the tensors to long with:
a = torch.tensor([0.0000, 0.0736, 0.5220, 0.0000, 0.0000, 0.1783, 0.0000, 0.0000, 0.0000,
0.0000]).type(torch.LongTensor)
b = torch.tensor([0.4596, 0.4635, 0.5073, 0.4358, 0.5551, 0.5089, 0.5348, 0.5573, 0.5656,
0.5886]).type(torch.LongTensor)
And then the conditional operation works without an error:
torch.where(a < b, 0, a)
However, 1. it gives the wrong answer (it just converts everything to zeros):
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
And 2. it loses the gradient.
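I suspect both problems come from the cast itself: every value in these tensors is below 1.0, so converting to long truncates them all to 0, and autograd only tracks floating-point tensors, which would explain the missing grad_fn. A quick check:
x = torch.tensor([0.2716, 0.8649]).type(torch.LongTensor)
print(x)  # tensor([0, 0]) -- the fractional parts are truncated away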
I also tried it with 2 simple tensors:
a = torch.tensor([1,2,3,4])
threshold = torch.tensor([0.5,2.3,2.9,4.2])
torch.where(a < threshold, 0, a)
>>> tensor([1, 0, 3, 0])
And this seems to work, though I don't have a reference for how gradients behave in this case, and I don't know why it works here but not above (I need the first case to work).
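My guess for why this case runs is that a is an integer tensor here, so its scalar type matches the integer literal 0, whereas in the first example a is float:
a = torch.tensor([1, 2, 3, 4])
print(a.dtype)  # torch.int64, the same scalar type as the literal 0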
CodePudding user response:
I believe that where is non-differentiable (see here).
I think you can get a similar effect using the following:
torch.nn.Sigmoid()(-1e5*(b - a)) * a
The idea is to smooth out the non-differentiable step function (where) with a sigmoid. You can make the step steeper by increasing the scale factor (1e5).
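As a quick sanity check (a minimal sketch using the functional torch.sigmoid and random inputs rather than the exact tensors above), gradients flow through this formulation:
import torch

a = torch.rand(1, 10, requires_grad=True)
b = torch.rand(1, 10, requires_grad=True)

k = 1e5  # steepness of the soft step; larger means closer to a hard where()
out = torch.sigmoid(k * (a - b)) * a  # ~a where a > b, ~0 where a < b

out.sum().backward()  # backprop runs, unlike with a hard threshold
print(a.grad)
print(b.grad)
One caveat: with a very large k the sigmoid saturates, so the gradient with respect to b is nearly zero everywhere except where a is close to b; a smaller k gives a softer threshold but more informative gradients.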