Pytorch argmax across multiple dimensions

Time:10-11

I have a 4D tensor, and I would like to get the argmax across the last two dimensions. torch.argmax only accepts integers as the "dim" argument, not tuples.

How can I accomplish this?

Here's what I had in mind, but I can't figure out how to match up the dimensions of my two "indices" tensors. original_array has shape [1, 512, 37, 59].

max_vals, indices_r = torch.max(original_array, dim=2)
max_vals, indices_c = torch.max(max_vals, dim=2)
indices = torch.hstack((indices_r, indices_c))
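The two-step idea above can work. The missing piece is that indices_r still has a column axis (shape [1, 512, 59]), so the row index has to be looked up at the column that won the second max rather than concatenated wholesale with hstack. A minimal sketch using torch.gather, assuming random data stands in for the real tensor:

```python
import torch

# Assumption: same shape as in the question, random values for illustration.
original_array = torch.rand(1, 512, 37, 59)

# Step 1: max over rows (dim 2). indices_r[b, c, j] is the row of the max in column j.
max_vals, indices_r = torch.max(original_array, dim=2)   # both (1, 512, 59)

# Step 2: max over the remaining column dim. indices_c[b, c] is the winning column.
max_vals, indices_c = torch.max(max_vals, dim=2)          # both (1, 512)

# Select indices_r[b, c, indices_c[b, c]]: the row that belongs to the winning column.
rows = indices_r.gather(-1, indices_c.unsqueeze(-1)).squeeze(-1)  # (1, 512)
indices = torch.stack((rows, indices_c), dim=-1)                  # (1, 512, 2)
```

Each entry of indices is a (row, column) pair, and max_vals now holds the maximum of each 37x59 plane.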

Answer:

As others mentioned, it's best to flatten the last two dimensions and then take the argmax (here via max, which returns the indices as its second output):

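import torch

original_array = torch.rand(1, 512, 37, 59)
original_flatten = original_array.view(1, 512, -1)  # merge the last two dims: (1, 512, 37*59)
_, max_ind = original_flatten.max(-1)               # linear index of the max in each 37x59 plane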
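One caveat with view(1, 512, -1): it hardcodes the leading sizes and only works when the two merged dimensions are contiguous in memory. A minimal sketch, assuming a transposed tensor as an example of a non-contiguous input, showing that flatten(-2) handles this case while view does not:

```python
import torch

# Assumption: a non-contiguous (transposed) tensor with the question's logical shape.
x = torch.rand(1, 512, 59, 37).transpose(-2, -1)  # shape (1, 512, 37, 59)

# view() needs the merged dims to be laid out contiguously, so it raises here.
try:
    x.view(1, 512, -1)
    view_ok = True
except RuntimeError:
    view_ok = False

# flatten(-2) merges the last two dims for any leading shape (copying only if needed).
flat = x.flatten(-2)         # (1, 512, 37 * 59), logical row-major order
_, max_ind = flat.max(-1)    # (1, 512) linear indices into each 37x59 plane
```

reshape(1, 512, -1) would also work, but flatten(-2) has the advantage of not needing the leading sizes spelled out.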

max_ind then holds the linear (flattened) index of each maximum value. If you want the 2D indices of the maximum values instead, you can "unflatten" the linear indices using the number of columns:

# 59 is the number of columns for the (37, 59) part
coords = torch.stack([max_ind // 59, max_ind % 59], -1)  # (row, column) pairs

This gives a tensor of shape (1, 512, 2) whose last dimension holds the (row, column) coordinates of each maximum.
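To avoid hardcoding 59, the column count can be read from the tensor itself. The sketch below does that and checks the recovered coordinates against amax over the last two dims. It uses torch.div(..., rounding_mode="floor"), which matches // for these non-negative indices; some PyTorch releases emitted a deprecation warning for // on tensors, so the explicit form is a hedge rather than a requirement.

```python
import torch

original_array = torch.rand(1, 512, 37, 59)
W = original_array.shape[-1]                        # number of columns (59)

_, max_ind = original_array.flatten(-2).max(-1)     # (1, 512) linear indices

# Row-major layout: linear index = row * W + col.
rows = torch.div(max_ind, W, rounding_mode="floor")
cols = max_ind % W
coords = torch.stack([rows, cols], -1)              # (1, 512, 2)

# Sanity check: the coordinates point at the per-channel maxima.
channels = torch.arange(original_array.shape[1])
picked = original_array[0, channels, rows[0], cols[0]]
assert torch.equal(picked, original_array.amax(dim=(-2, -1))[0])
```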

Answer:

You could flatten the last two dimensions with torch.flatten and apply torch.argmax to the result:

>>> x = torch.rand(2,3,100,100)
>>> x.flatten(-2).argmax(-1)
tensor([[2660, 6328, 8166],
        [5934, 5494, 9717]])
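The same flatten-then-argmax idea extends to dimensions that are not adjacent or not last, by first moving them to the end with Tensor.movedim. A sketch of a hypothetical helper, argmax_over_dims (the name and its return layout are assumptions for illustration, not a PyTorch API):

```python
import torch


def argmax_over_dims(x: torch.Tensor, dims) -> torch.Tensor:
    """Hypothetical helper (not a PyTorch API): argmax over several dims.

    Returns coordinates of shape (*remaining_dims, len(dims)), listed in the
    order the dims were given.
    """
    n = x.dim()
    dims = tuple(d % n for d in dims)           # normalise negative dims
    k = len(dims)
    # Move the reduced dims to the end (in the requested order); the other
    # dims keep their relative order at the front.
    moved = x.movedim(dims, tuple(range(n - k, n)))
    sizes = moved.shape[n - k:]
    flat_idx = moved.flatten(n - k).argmax(-1)  # row-major linear index
    # Unravel the linear index: the last dim varies fastest.
    coords = []
    for size in reversed(sizes):
        coords.append(flat_idx % size)
        flat_idx = torch.div(flat_idx, size, rounding_mode="floor")
    return torch.stack(coords[::-1], dim=-1)


x = torch.rand(2, 3, 4, 5)
c = argmax_over_dims(x, (1, 3))   # reduce over dims 1 and 3 -> shape (2, 4, 2)
```

For the trailing case, argmax_over_dims(x, (-2, -1)) reduces to flatten(-2).argmax(-1) followed by the divide/modulo unflattening.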