Pytorch: Set indexes in a tensor based on a list of tensor indices

Time:07-20

Is there a way to efficiently set the values of a tensor based on a tensor of indices and a tensor of values?

tensor_to_change = tensor([[-36.9127, -45.6596, -47.1595],
        [-36.9409, -45.7024, -47.2050],
        [-36.9865, -45.7665, -47.2711],
        [-36.3202, -36.9561, -47.2066],
        [-36.2929, -36.9333, -47.1702]])
tensor_of_indices = tensor([[0],
        [0],
        [0],
        [1],
        [1]])
tensor_of_values = tensor([[-37.9409],
        [-38.4865],
        [-36.9561],
        [-34.9561],
        [-38.7562]])

I can accomplish this with a for loop, but this step then becomes really slow:

for i, a in enumerate(tensor_of_indices):
    tensor_to_change[i][a] = tensor_of_values[i]

Is there a torch function which can do this faster?

Answer:

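Try integer-array indexing, which pairs each row index with its column index and assigns all the values in one call: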

import torch

rows = torch.arange(tensor_to_change.size(0))  # one row index per row: 0..n-1
cols = tensor_of_indices.squeeze(1)            # shape (n, 1) -> (n,)
tensor_to_change[rows, cols] = tensor_of_values.squeeze(1)  # sets [rows[i], cols[i]]

Output:

tensor_to_change
tensor([[-37.9409, -45.6596, -47.1595],
        [-38.4865, -45.7024, -47.2050],
        [-36.9561, -45.7665, -47.2711],
        [-36.3202, -34.9561, -47.2066],
        [-36.2929, -38.7562, -47.1702]])
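
If you would rather not build the row indices yourself, `Tensor.scatter_` along dim 1 should give the same result, since the index and value tensors in the question already have shape (n, 1). The following is only a minimal sketch under that assumption; the names `expected`, `via_index` and `via_scatter` are made up for this example, and the loop from the question is rerun on a copy purely as a reference to compare against.

```python
import torch

tensor_to_change = torch.tensor([[-36.9127, -45.6596, -47.1595],
                                 [-36.9409, -45.7024, -47.2050],
                                 [-36.9865, -45.7665, -47.2711],
                                 [-36.3202, -36.9561, -47.2066],
                                 [-36.2929, -36.9333, -47.1702]])
tensor_of_indices = torch.tensor([[0], [0], [0], [1], [1]])  # int64 by default
tensor_of_values = torch.tensor([[-37.9409], [-38.4865], [-36.9561],
                                 [-34.9561], [-38.7562]])

# Reference: the loop from the question, run on a copy.
expected = tensor_to_change.clone()
for i, a in enumerate(tensor_of_indices):
    expected[i, int(a)] = tensor_of_values[i, 0]

# Integer-array indexing, as in the answer above.
via_index = tensor_to_change.clone()
rows = torch.arange(via_index.size(0))
via_index[rows, tensor_of_indices.squeeze(1)] = tensor_of_values.squeeze(1)

# scatter_ along dim 1: self[i][index[i][j]] = src[i][j].
# It modifies the tensor in place, so work on a copy here as well.
via_scatter = tensor_to_change.clone()
via_scatter.scatter_(1, tensor_of_indices, tensor_of_values)

print(torch.equal(via_index, expected), torch.equal(via_scatter, expected))
```

Both versions avoid the Python-level loop. `scatter_` requires the index tensor to be of integer (int64) dtype and to have the same number of dimensions as the tensor being modified, which is why the (n, 1) shapes are passed to it without squeezing.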
