I have a PyTorch tensor b with the shape torch.Size([10, 10, 51]). I want to select among the 10 possible positions along dimension d=1 (the middle one) using a NumPy array of indices, for example a = np.array([0,1,2,3,4,5,6,7,8,9]) (this is just an arbitrary example).
I wanted to do:
b[:,a,:]
but that isn't working
Answer:
Your solution is likely torch.index_select (see its documentation). You'll have to turn a into a tensor first, though, since the index argument must be an integer tensor (IntTensor or LongTensor):
a_torch = torch.from_numpy(a)  # shares memory with a, no copy
answer = torch.index_select(b, 1, a_torch)  # select along dim 1
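A self-contained version of the snippet above, assuming a random b with the question's shape. The explicit int64 cast and the second, shorter index list are illustrative additions (not part of the original answer) showing that any subset or reordering of positions works the same way:

```python
import numpy as np
import torch

# Tensor with the shape from the question: (10, 10, 51)
b = torch.rand(10, 10, 51)

# Index array from the question; casting to int64 guarantees a LongTensor
# after conversion, whatever NumPy's platform default integer type is
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).astype(np.int64)

# torch.from_numpy shares memory with the NumPy array (no copy)
a_torch = torch.from_numpy(a)

# Select along dim=1; the other dimensions are left intact
answer = torch.index_select(b, 1, a_torch)
print(answer.shape)  # torch.Size([10, 10, 51])

# Illustrative: keep only positions 2 and 7 of the middle dimension
subset = torch.index_select(b, 1, torch.tensor([2, 7]))
print(subset.shape)  # torch.Size([10, 2, 51])

# Same result as plain advanced indexing on the middle axis
assert torch.equal(answer, b[:, a_torch, :])
```

With this particular a (all ten positions in order), the result equals b itself; the selection only becomes visible with a subset or a permutation.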
Answer:
Indexing b directly on the second axis with a should work:
>>> import numpy as np
>>> import torch
>>> b = torch.rand(10, 10, 51)
>>> a = np.array([0,1,2,3,4,5,6,7,8,9])
>>> b[:, a].shape
torch.Size([10, 10, 51])
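The question's wording ("select one element between the 10 possible elements") could also mean picking a single position on dim 1 for each row of dim 0, which gives shape (10, 51) instead of (10, 10, 51). If that is the intent, a sketch using paired advanced indexing, cross-checked against torch.gather, might look like this; idx_np is a hypothetical index array, one entry per row:

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)

# Hypothetical per-row choices: row i of dim 0 keeps position idx[i] on dim 1
idx_np = np.array([3, 0, 9, 1, 1, 5, 2, 8, 4, 6])
idx = torch.from_numpy(idx_np).long()  # gather requires int64 indices

# Option 1: pair each row index with its chosen dim-1 index
rows = torch.arange(b.shape[0])
picked = b[rows, idx]
print(picked.shape)  # torch.Size([10, 51])

# Option 2: torch.gather along dim=1; the index needs as many dims as b,
# so reshape idx to (10, 1, 1) and expand it over the last dimension
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, b.shape[2])
picked_gather = torch.gather(b, 1, gather_idx).squeeze(1)

assert torch.equal(picked, picked_gather)
```

The paired-indexing form is usually the more readable of the two; gather is handy when the index tensor already has the same rank as the input.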