I have a 4D tensor, and I would like to get the argmax across the last two dimensions. torch.argmax only accepts an integer as the "dim" argument, not a tuple.
How can I accomplish this?
Here's what I had in mind, but I can't figure out how to match up the dimensions of my two "indices" tensors. original_array has shape [1, 512, 37, 59].
max_vals, indices_r = torch.max(original_array, dim=2)
max_vals, indices_c = torch.max(max_vals, dim=2)
indices = torch.hstack((indices_r, indices_c))
CodePudding user response:
As others mentioned, it's best to flatten the last two dimensions and apply argmax:
original_array = torch.rand(1, 512, 37, 59)
original_flatten = original_array.view(1, 512, -1)
_, max_ind = original_flatten.max(-1)
This gives you the linear (flattened) index of each maximum. If you want the 2D indices instead, you can "unflatten" the linear indices using the number of columns:
# 59 is the number of columns for the (37, 59) part
torch.stack([max_ind // 59, max_ind % 59], -1)
This gives you a tensor of shape (1, 512, 2), where the last dimension holds the (row, column) coordinates of each maximum.
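Putting the flatten-and-unflatten steps together, here is a minimal sketch. The helper name argmax_last_two_dims is made up for illustration, and it reads the column count from the tensor's shape rather than hard-coding 59:

```python
import torch

def argmax_last_two_dims(x: torch.Tensor) -> torch.Tensor:
    """Return (row, col) indices of the max over the last two dims.

    Hypothetical helper, not part of torch. Output shape is x.shape[:-2] + (2,).
    """
    n_cols = x.shape[-1]
    # Linear index into the flattened (rows * cols) block.
    flat_idx = x.flatten(-2).argmax(dim=-1)
    # Convert the linear index back to 2D coordinates.
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    return torch.stack([rows, cols], dim=-1)

original_array = torch.rand(1, 512, 37, 59)
indices = argmax_last_two_dims(original_array)
print(indices.shape)  # torch.Size([1, 512, 2])
```

Here indices[0, k] is the (row, column) pair of the maximum in original_array[0, k].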
CodePudding user response:
You could flatten the last two dimensions with torch.flatten and apply torch.argmax on the result:
>>> x = torch.rand(2,3,100,100)
>>> x.flatten(-2).argmax(-1)
tensor([[2660, 6328, 8166],
        [5934, 5494, 9717]])
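For comparison, the two-step torch.max approach from the question can probably be made to work as well. The sketch below assumes the mismatch comes from indices_r holding one row index per column (shape [1, 512, 59]) while indices_c has shape [1, 512]; gathering the row index at the winning column lines the two up:

```python
import torch

original_array = torch.rand(1, 512, 37, 59)

# Max over rows (dim=2): both results have shape (1, 512, 59).
max_vals, indices_r = torch.max(original_array, dim=2)
# Max over the remaining column dim: both results have shape (1, 512).
max_vals, indices_c = torch.max(max_vals, dim=2)
# Keep only the row index that belongs to the winning column.
indices_r = indices_r.gather(2, indices_c.unsqueeze(-1)).squeeze(-1)
# Pair them up as (row, col); result has shape (1, 512, 2).
indices = torch.stack((indices_r, indices_c), dim=-1)
```

The flatten-based version is shorter, but this one avoids the integer division and modulo.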