Copy tensor elements of certain indices in PyTorch

Time:05-30

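The desired operation is similar in spirit to torch.Tensor.index_copy, but in the opposite direction: instead of writing values into the positions given by an index tensor, it reads values from those positions.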
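For contrast, here is a small sketch of what `index_copy_` actually does. The `target` tensor and the index values `[4, 0, 2]` are arbitrary, chosen only for illustration: `index_copy_` writes `source[i]` into `target[index[i]]`, whereas the question needs each output element to be read from `A[index[i]]`.

```python
import torch

A = torch.tensor([10, 20, 30])

# Hypothetical destination tensor, used only to show the write direction.
target = torch.zeros(5, dtype=A.dtype)

# index_copy_ scatters: target[index[i]] = A[i] along dim 0.
target.index_copy_(0, torch.tensor([4, 0, 2]), A)
print(target)  # tensor([20,  0, 30,  0, 10])
```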

It's best explained with an example.

Tensor A has original values that we will copy:

[10, 20, 30]

Tensor B holds indices into A:

[0, 1, 0, 1, 2, 1]

Tensor C has the same length as B and contains the values of A at those indices:

[10, 20, 10, 20, 30, 20]
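Put differently, C is defined by `C[i] = A[B[i]]`. A minimal loop-based reference, i.e. the approach the question wants to avoid, makes that definition explicit:

```python
import torch

A = torch.tensor([10, 20, 30])
B = torch.tensor([0, 1, 0, 1, 2, 1])

# Reference definition: C[i] = A[B[i]] for every position i of B (Python loop).
C = torch.tensor([A[i].item() for i in B.tolist()])
print(C)  # tensor([10, 20, 10, 20, 30, 20])
```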

What's a good way to make C from A and B in PyTorch, without using loops?

Answer:

Have you tried just indexing A with B?

In [1]: import torch
  
In [2]: a = torch.tensor([20,30,40])

In [3]: b = torch.tensor([0,1,2,1,1,2,0,0,1,2])

In [4]: a[b]
Out[4]: tensor([20, 30, 40, 30, 30, 40, 20, 20, 30, 40])
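Applied to the tensors from the question, the same indexing produces C directly. The sketch below also shows two explicit calls, `torch.index_select` and `torch.gather`, that give the same result for a 1-D tensor; which one to use is mostly a matter of taste.

```python
import torch

A = torch.tensor([10, 20, 30])
B = torch.tensor([0, 1, 0, 1, 2, 1])

# Integer-tensor (advanced) indexing: C[i] = A[B[i]], no Python loop.
C = A[B]

# Equivalent explicit calls along dimension 0 for a 1-D tensor.
C_select = torch.index_select(A, 0, B)
C_gather = torch.gather(A, 0, B)

print(C)  # tensor([10, 20, 10, 20, 30, 20])
```

Indexing with a tensor of indices returns a new tensor (a copy, not a view of A), so modifying C afterwards does not change A.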