I'm trying to reverse the order of the rows in a tensor that I create. I have tried with TensorFlow and PyTorch. The only thing I have found is the torch.flip() method. This does not work, because it reverses not only the order of the rows but also the elements within each row. I want the elements in each row to stay in the same order. Is there an array operation for this, such as indexing the rows with a list of integers? For instance:
tensor_a = [[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]]
I want it to be returned as:
[[7, 8, 9],
 [4, 5, 6],
 [1, 2, 3]]
However, torch.flip(tensor_a, dims=(0, 1)) returns:
[[9, 8, 7],
 [6, 5, 4],
 [3, 2, 1]]
Does anyone have any suggestions?
Answer:
According to the documentation, torch.flip has a dims argument that controls which axes are flipped. In this case, torch.flip(tensor_a, dims=(0,)) returns the expected result. By contrast, torch.flip(tensor_a, dims=(0, 1)) reverses the whole tensor, and torch.flip(tensor_a, dims=(1,)) reverses every row, e.g. [1, 2, 3] --> [3, 2, 1]. Note that dims is required: there is no default that flips every axis.
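A minimal sketch, assuming a PyTorch release recent enough to include torch.flipud, that compares dims=(0,) with two equivalent approaches: indexing the rows with a reversed range of integers (close to what the question asks for) and torch.flipud, which flips dimension 0. The variable names are illustrative only. NumPy-style slicing such as tensor_a[::-1] is not an option here, because PyTorch does not support negative slice steps and raises an error.

```python
import torch

# The 3x3 tensor from the question.
tensor_a = torch.tensor([[1, 2, 3],
                         [4, 5, 6],
                         [7, 8, 9]])

# Option 1: flip along dimension 0 only (row order reversed, rows left intact).
flipped = torch.flip(tensor_a, dims=(0,))

# Option 2: index the rows with a reversed integer range, here [2, 1, 0].
rev_idx = torch.arange(tensor_a.size(0) - 1, -1, -1)
indexed = tensor_a[rev_idx]

# Option 3: torch.flipud is shorthand for flipping dimension 0.
flipped_ud = torch.flipud(tensor_a)

print(flipped)
# tensor([[7, 8, 9],
#         [4, 5, 6],
#         [1, 2, 3]])
print(torch.equal(flipped, indexed), torch.equal(flipped, flipped_ud))  # True True
```

All three produce the same tensor; the index-based version is the most flexible if you later need an arbitrary row permutation rather than a plain reversal.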
Answer:
You need to specify the axis to flip along, as below:
import torch
a = torch.arange(9).view(3, 3)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
torch.flip(a, [0])
# tensor([[6, 7, 8],
#         [3, 4, 5],
#         [0, 1, 2]])
torch.flip(a, [1])
# tensor([[2, 1, 0],
#         [5, 4, 3],
#         [8, 7, 6]])
torch.flip(a, [0, 1])
# tensor([[8, 7, 6],
#         [5, 4, 3],
#         [2, 1, 0]])
As you can see, you need torch.flip(a, [0]) to reverse only the order of the rows.
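One detail worth knowing, per the PyTorch documentation for torch.flip: unlike NumPy's np.flip, which returns a view, torch.flip copies the data. The sketch below (variable names are illustrative) checks that writing into the flipped result leaves a unchanged, and shows that a negative dim such as -2 selects the row axis, which also works for a batch of matrices with shape (batch, rows, cols).

```python
import torch

a = torch.arange(9).view(3, 3)

# Reverse the row order only.
rows_reversed = torch.flip(a, [0])

# torch.flip returns a copy, so writing into the result does not touch `a`.
rows_reversed[0, 0] = 100
print(a[2, 0].item())  # 6: the original tensor is unchanged

# Negative dims count from the end: -2 is the row axis of each matrix,
# so this reverses the rows of every matrix in the batch independently.
batch = torch.arange(18).view(2, 3, 3)
batch_rows_reversed = torch.flip(batch, [-2])
print(batch_rows_reversed[0])
# tensor([[6, 7, 8],
#         [3, 4, 5],
#         [0, 1, 2]])
```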