PyTorch "where" conditional -- RuntimeError: expected scalar type long long but found float

Time:09-28

I have 2 tensors with gradients:

a = tensor([[0.0000, 0.0000, 0.2716, 0.0000, 0.4049, 0.0000, 0.2126, 0.8649, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

b = tensor([[0.5842, 0.4618, 0.4047, 0.5714, 0.4841, 0.5683, 0.4030, 0.3779, 0.4436,
         0.4365]], grad_fn=<SigmoidBackward>)

I'm trying to use the second tensor (b) as a threshold while maintaining the differentiability of the tensors:

torch.where(a < b, 0, a)

However, I get the following error:

RuntimeError: expected scalar type long long but found float

I can convert the tensors to long with

a = torch.tensor([0.0000, 0.0736, 0.5220, 0.0000, 0.0000, 0.1783, 0.0000, 0.0000, 0.0000,
         0.0000]).type(torch.LongTensor)

b = torch.tensor([0.4596, 0.4635, 0.5073, 0.4358, 0.5551, 0.5089, 0.5348, 0.5573, 0.5656,
         0.5886]).type(torch.LongTensor)

And then the conditional operation works without an error:

torch.where(a < b, 0, a)

However, 1. it gives me the wrong answer (every element comes out as zero):

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

And 2. it loses the gradient.
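Both problems come from the cast itself rather than from where: converting to an integer dtype truncates toward zero, so every value in [0, 1) becomes 0, and autograd only supports floating-point (and complex) tensors, so an integer tensor cannot carry a gradient. A minimal check using a few of the values above:

```python
import torch

# A few of the float values from the tensor above
a_float = torch.tensor([0.0736, 0.5220, 0.1783])

# Casting to int64 truncates toward zero, so everything in [0, 1) becomes 0
a_long = a_float.type(torch.LongTensor)
print(a_long)  # tensor([0, 0, 0])

# Integer tensors cannot be made to require gradients
try:
    a_long.requires_grad_(True)
    can_require_grad = True
except RuntimeError:
    can_require_grad = False
print(can_require_grad)  # False
```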

I also tried it with 2 simple tensors:

a = torch.tensor([1,2,3,4])
threshold = torch.tensor([0.5,2.3,2.9,4.2])

torch.where(a < threshold, 0, a)

>>>tensor([1, 0, 3, 0])

This seems to work, though I don't know how it behaves with respect to gradients, nor why it works here but not in the first case (which is the one I need to work).
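The difference appears to be the dtype of a. In the second example a is an int64 tensor, so the Python integer 0 matches it; in the first example a is float32. As far as I know, some older PyTorch releases only accepted a Python integer scalar in torch.where alongside a torch.long tensor, which would explain the "expected scalar type long long" message, while newer releases apply type promotion instead (I'm not certain from which version). A version-independent workaround, sketched below, is to pass a zero tensor with the same dtype as a instead of the scalar:

```python
import torch

# Integer example from the question: the scalar 0 already matches int64
a_int = torch.tensor([1, 2, 3, 4])
print(a_int.dtype)  # torch.int64

# Float example, using four of the values from the question
a = torch.tensor([0.0000, 0.2716, 0.4049, 0.8649], requires_grad=True)
b = torch.tensor([0.5842, 0.4047, 0.4841, 0.3779])

# torch.zeros_like(a) has a's dtype (float32), so both branches match
out = torch.where(a < b, torch.zeros_like(a), a)
print(out)
```

The result stays float32 and remains attached to the autograd graph through a.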

Answer:

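I believe that where is non-differentiable with respect to its condition (see here): a < b produces a boolean mask, so no gradient reaches the threshold b, although a still receives gradients through whichever branch is selected. I think you can get a similar effect using the following: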
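A quick way to see what where does with gradients: the kept elements of a receive a gradient, but b only appears inside the comparison, so it receives none. A minimal sketch using four of the values from the question:

```python
import torch

a = torch.tensor([0.0000, 0.2716, 0.4049, 0.8649], requires_grad=True)
b = torch.tensor([0.5842, 0.4047, 0.4841, 0.3779], requires_grad=True)

# Hard threshold: zero out elements of a that fall below b
out = torch.where(a < b, torch.zeros_like(a), a)
out.sum().backward()

# a gets gradient 1 where it was kept and 0 where it was zeroed out
print(a.grad)  # tensor([0., 0., 0., 1.])
# b only enters through the boolean comparison a < b, so it gets nothing
print(b.grad)  # None
```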

torch.nn.Sigmoid()(-1e5*(b - a)) * a

The idea is to smooth out the non-differentiable step function (where) with a sigmoid. You can make it steeper by increasing the scale factor (1e5).
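A minimal sketch of the same idea with the steepness exposed as a parameter. The helper name soft_threshold and the parameter k are hypothetical, introduced only for illustration; torch.sigmoid(k * (a - b)) is the same quantity as torch.nn.Sigmoid()(-k * (b - a)) above:

```python
import torch

def soft_threshold(a, b, k=1e5):
    # Hypothetical helper: sigmoid(k * (a - b)) is close to 1 where a > b
    # and close to 0 where a < b, so multiplying by a approximates
    # torch.where(a < b, 0, a). Larger k gives a steeper step.
    return torch.sigmoid(k * (a - b)) * a

a = torch.tensor([0.0000, 0.2716, 0.4049, 0.8649], requires_grad=True)
b = torch.tensor([0.5842, 0.4047, 0.4841, 0.3779], requires_grad=True)

# Very steep: matches the hard threshold, but the sigmoid saturates
steep = soft_threshold(a, b, k=1e5)
print(steep)

# Gentler: a softer threshold, but b now receives a nonzero gradient
soft = soft_threshold(a, b, k=10.0)
soft.sum().backward()
print(b.grad)
```

With a scale of 1e5 the sigmoid is saturated for these values, so the gradient with respect to b is effectively zero unless a lies within a few multiples of 1/k of b. If the goal is to train the threshold, a much smaller scale may be needed; the value 10 here is an arbitrary illustration, and the trade-off is a less sharp cutoff.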
