Indexing a multi-dimensional tensor using only one dimension

Time:10-05

I have a PyTorch tensor b with shape torch.Size([10, 10, 51]). I want to select one element among the 10 possible elements along dimension d=1 (the middle one) using a NumPy array: a = np.array([0,1,2,3,4,5,6,7,8,9]). This is just a random example.

I tried b[:,a,:], but that isn't working.

Answer:

Your solution is likely torch.index_select (see the PyTorch docs).

You'll have to turn a into a tensor first, though.

a_torch = torch.from_numpy(a)               # convert the NumPy index array to a tensor
answer = torch.index_select(b, 1, a_torch)  # select along dim 1; shape (10, 10, 51)
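A minimal, self-contained sketch of the index_select approach, assuming the question's shapes. The tensor b is random and the shuffled index array is made up for illustration; neither comes from the original post. The sketch also checks that the result matches plain advanced indexing along dim 1. The .long() call is a conservative choice so the index is int64 regardless of NumPy's platform-dependent default integer type.

```python
import numpy as np
import torch

# Hypothetical example data with the question's shape (10, 10, 51).
b = torch.rand(10, 10, 51)
# Illustrative index array; any integers in [0, 10) work, repeats included.
a = np.array([9, 0, 3, 3, 7, 1, 2, 8, 5, 4])

# torch.from_numpy shares memory with `a`; .long() ensures int64
# (it is a no-op if the array is already int64).
a_torch = torch.from_numpy(a).long()

# Select along dim 1: answer[:, j, :] == b[:, a[j], :]
answer = torch.index_select(b, 1, a_torch)
print(answer.shape)  # torch.Size([10, 10, 51])

# Same values as advanced indexing on the middle axis.
print(torch.equal(answer, b[:, a_torch, :]))  # True
```

Note that index_select keeps every requested position, so with 10 indices the result has the same shape as b.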

Answer:

Indexing b directly along the second axis with a should work:

>>> import numpy as np
>>> import torch
>>> b = torch.rand(10, 10, 51)
>>> a = np.array([0,1,2,3,4,5,6,7,8,9])

>>> b[:, a].shape
torch.Size([10, 10, 51])
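Both answers above keep all 10 positions along dim 1, so the result is still (10, 10, 51). If "select one element" is instead meant as picking, for each row i of dim 0, the single position a[i] along dim 1 (a (10, 51) result), one way is to pair a with torch.arange in advanced indexing. This reading of the question is an assumption, and the sketch below also shows torch.gather as an equivalent alternative.

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

idx = torch.as_tensor(a, dtype=torch.long)  # one dim-1 position per row of dim 0
rows = torch.arange(b.shape[0])             # 0, 1, ..., 9

# Advanced indexing pairs rows[i] with idx[i]: picked[i] == b[i, idx[i], :]
picked = b[rows, idx]
print(picked.shape)  # torch.Size([10, 51])

# Equivalent with torch.gather: the index needs as many dims as b,
# so expand idx to shape (10, 1, 51) and squeeze dim 1 afterwards.
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, b.shape[2])
picked_g = torch.gather(b, 1, gather_idx).squeeze(1)
print(torch.equal(picked, picked_g))  # True
```

With a = [0, 1, ..., 9] as in the question, this picks the "diagonal" b[i, i, :] for each i.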