PyTorch: count the number of tensor values that are near (+/- a tolerance) the values of a reference

Time:02-20

I have two tensors with the same arbitrary, multi-dimensional shape:

  • target_tensor
  • predicted_tensor

I want to count the number of values in predicted_tensor that are close to the corresponding values of target_tensor, i.e. within +/- tolerance of them.

With a for loop, it would be something like this:

import torch

targets = torch.flatten(target_tensor)
predicted = torch.flatten(predicted_tensor)

correct_values = 0
tolerance = 0.1

# Count predictions that fall strictly inside (target - tolerance, target + tolerance)
for i, prediction in enumerate(predicted):
    target = targets[i]
    if target - tolerance < prediction < target + tolerance:
        correct_values += 1

However, a Python for loop performs poorly, especially on large tensors.

I'm looking for a vectorized solution. I tried:

torch.sum(target - tolerance < prediction < target + tolerance)

But I got:

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In Julia, I would just add a dot to make the operation element-wise.
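For reference, the error comes from Python rather than from PyTorch: a chained comparison such as a < b < c is evaluated as (a < b) and (b < c), and "and" calls bool() on the first boolean tensor, which is ambiguous when the tensor has more than one element. PyTorch's comparison operators are already element-wise, so one possible fix is to split the chain and combine the two masks with the element-wise & operator. A minimal sketch, assuming made-up example tensors of the same shape:

```python
import torch

# Hypothetical example data; any two tensors of the same shape would work.
target_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
predicted_tensor = torch.tensor([[1.05, 2.5], [2.95, 4.2]])
tolerance = 0.1

# Each comparison yields a boolean tensor with the inputs' shape.
# The parentheses are required: & binds more tightly than < in Python,
# so without them the expression would become a chained comparison again.
mask = (target_tensor - tolerance < predicted_tensor) & (
    predicted_tensor < target_tensor + tolerance
)

# Summing a boolean tensor counts its True entries; .item() gives a Python int.
correct_values = mask.sum().item()
print(correct_values)  # 2
```

This keeps the strict inequalities of the original loop, so it should count exactly the same elements.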

Any idea how to implement this in PyTorch with a short, vectorized solution?

Thanks

Answer:

I think you are looking for torch.isclose:

correct_values = torch.isclose(predicted_tensor, target_tensor, atol=tolerance, rtol=0).sum()
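torch.isclose checks |input - other| <= atol + rtol * |other| element-wise. Its default rtol is nonzero, so passing rtol=0 turns it into a pure absolute-tolerance test. Note that the test is inclusive, whereas the loop in the question uses strict inequalities, so a prediction lying exactly at target +/- tolerance is counted here but not by the loop. A small sketch on hypothetical example data, including an equivalent formulation without isclose:

```python
import torch

# Hypothetical example data (same shape; values chosen for illustration).
target_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
predicted_tensor = torch.tensor([[1.05, 2.5], [2.95, 4.2]])
tolerance = 0.1

# Element-wise |predicted - target| <= tolerance, since rtol=0.
close = torch.isclose(predicted_tensor, target_tensor, atol=tolerance, rtol=0)

# .sum() returns a 0-d tensor; .item() converts it to a Python int.
correct_values = close.sum().item()
print(correct_values)  # 2

# Equivalent count written with basic tensor ops (also inclusive at the boundary).
manual = ((predicted_tensor - target_tensor).abs() <= tolerance).sum().item()
print(manual)  # 2
```

Both versions work on tensors of any shape without flattening, because every operation is applied element-wise.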