Index selection in case of conflict in PyTorch argmax

Time:08-30

I have been trying to learn tensor operations and this one has thrown me for a loop.
Let us say I have one tensor t:

    t = torch.tensor([
        [1,0,0,2],
        [0,3,3,0],
        [4,0,0,5]
    ], dtype=torch.float32)

This is a rank-2 tensor, so we can apply argmax along either dimension. Let us say we apply max along dim = 1, which returns both the maximum values and their indices:

    t.max(dim=1)
    # (tensor([2., 3., 5.]), tensor([3, 2, 3]))

The result is as expected: along dim = 1, the max elements are 2, 3, and 5. But there is a conflict in the second row, where the value 3 appears twice.
How is it resolved? Is the index chosen arbitrarily, or is there a fixed order of selection, such as left to right or highest index first?
I'd appreciate any insights into how this is resolved!

CodePudding user response:

That is a good question, and one I have stumbled over a couple of times myself. The simplest answer is that there is no guarantee that torch.argmax (or torch.max(x, dim=k), which also returns indices when dim is specified) will return the same index consistently. Instead, it may return any valid index of a maximal value, possibly at random. As this thread in the official forum discusses, this is considered to be desired behavior. (I know that there is another thread I read a while ago that makes this more explicit, but I cannot find it again.) Note that the documentation of more recent PyTorch releases does state that the index of the first maximal value is returned, so check the docs for the version you are running.
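If you want to see what your own installation does on ties, a minimal sketch like the following may help. It reuses the tensor t from the question; the names picked, tie_mask and rows_with_ties are only illustrative. It checks nothing more than that the returned index points at a maximal element and reports which rows contain a tie, so it makes no assumption about which of the tied indices comes back.

```python
import torch

t = torch.tensor([
    [1, 0, 0, 2],
    [0, 3, 3, 0],
    [4, 0, 0, 5],
], dtype=torch.float32)

values, indices = t.max(dim=1)

# Whichever index is chosen on a tie, it must point at a maximal element.
picked = t.gather(1, indices.unsqueeze(1)).squeeze(1)
assert torch.equal(picked, values)

# Mark every position that holds the row maximum, then flag rows with ties.
tie_mask = t == values.unsqueeze(1)
rows_with_ties = tie_mask.sum(dim=1) > 1

print(indices)         # the index reported for the tied row may be 1 or 2
print(rows_with_ties)  # tensor([False,  True, False])
```

Only the second row is ambiguous here, so that is the only row where the reported index can differ between versions or devices.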

Having said that, since this behavior was unacceptable for my use case, I wrote the following functions, which find the leftmost and rightmost indices at which a condition holds (be aware that condition is a function object you pass in):

    import torch


    def __consistent_args(input, condition, indices):
        assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
        # Weight every position where the condition holds by a unique positive
        # value, so that the argmax of each row is unambiguous.
        mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
        return torch.argmax(mask, dim=1)


    def consistent_find_leftmost(input, condition):
        # Weights decrease from left to right, so the leftmost match wins.
        indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
        return __consistent_args(input, condition, indices)


    def consistent_find_rightmost(input, condition):
        # Weights increase from left to right and start at 1 (not 0), so a lone
        # match in column 0 still outweighs the non-matching columns.
        indices = torch.arange(1, input.size(1) + 1, dtype=torch.float, device=input.device)
        return __consistent_args(input, condition, indices)

    # one example:
    consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x > 5)
    # will return:
    # tensor([6])

Hope they will help! (Oh, and please let me know if you have a better implementation that does the same)
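One possible alternative, offered only as a sketch: if the condition you care about is "equals the maximum", the same idea can be written for any dimension by masking the maximal positions and taking the min or max of their indices. The helper names first_argmax and last_argmax are hypothetical, not part of PyTorch, and the sketch assumes the input contains no NaNs.

```python
import torch


def _positions(x, dim):
    # Index values 0..n-1 laid out along `dim`, broadcastable against x.
    n = x.size(dim)
    shape = [1] * x.dim()
    shape[dim] = n
    return torch.arange(n, device=x.device).view(shape)


def first_argmax(x, dim=-1):
    """Index of the first maximal element along `dim` (hypothetical helper)."""
    is_max = x == x.max(dim=dim, keepdim=True).values
    positions = _positions(x, dim)
    # Non-maximal positions get the sentinel n, so min picks the first maximum.
    sentinel = torch.full_like(positions, x.size(dim))
    return torch.where(is_max, positions, sentinel).min(dim=dim).values


def last_argmax(x, dim=-1):
    """Index of the last maximal element along `dim` (hypothetical helper)."""
    is_max = x == x.max(dim=dim, keepdim=True).values
    positions = _positions(x, dim)
    # Non-maximal positions get the sentinel -1, so max picks the last maximum.
    sentinel = torch.full_like(positions, -1)
    return torch.where(is_max, positions, sentinel).max(dim=dim).values


t = torch.tensor([
    [1, 0, 0, 2],
    [0, 3, 3, 0],
    [4, 0, 0, 5],
], dtype=torch.float32)

print(first_argmax(t, dim=1))  # tensor([3, 1, 3])
print(last_argmax(t, dim=1))   # tensor([3, 2, 3])
```

Because the tie-breaking is done on integer positions that are all distinct, the result does not depend on how torch.max or torch.argmax happen to resolve ties internally.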
