PyTorch/NumPy batched submatrix indexing

Time:03-11

I have a single square source matrix L of shape (N, N):

import torch as pt
import numpy as np

N = 4
L = pt.arange(N*N).reshape(N, N)  # or np.arange(N*N).reshape(N, N)
L = tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11],
            [12, 13, 14, 15]])

and a matrix (vector of vectors) of boolean masks m of shape (K, N) according to which I'd like to extract submatrices from L.

K = 3
m = pt.tensor([[ True,  True, False, False],
               [False,  True,  True, False],
               [False,  True, False,  True]])

I know how to extract a single submatrix using a single mask vector by calling L[m[i]][:, m[i]] for any i. So, for example, for i=0, we'd get

tensor([[ 0,  1],
        [ 4,  5]])

but I need to perform the operation along the entire "batch" dimension. The end result I'm looking for then could be achieved by

res = []
for i in range(K):
    res.append(L[m[i]][:, m[i]])
output = pt.stack(res)

However, I hope there is a better solution that avoids the for loop. I realize that the loop itself would fail if the sum of m along the last dimension (dim/axis=1) weren't constant, because pt.stack needs equally sized submatrices. If I can guarantee that the sum is constant, is there a better solution? If there isn't, would changing the selector representation help? I chose boolean masks for convenience, but performance matters more to me.

CodePudding user response:

Notice that you can get the first square by indexing together with broadcasting:

import torch

r = torch.tensor([0,1])
L[r[:,None], r]

output:

tensor([[0, 1],
        [4, 5]])
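
To see why this works, it can help to look at the shapes involved. The following is only a minimal sketch, assuming the same L as in the question; the names rows, cols and mask are introduced here for illustration. r[:, None] has shape (2, 1) and r has shape (2,), so the two index tensors broadcast to (2, 2), and each output entry (a, b) is L[r[a], r[b]].

```python
import torch

N = 4
L = torch.arange(N * N).reshape(N, N)
r = torch.tensor([0, 1])

rows = r[:, None]     # shape (2, 1): the row index varies down the first axis
cols = r              # shape (2,):   the column index varies along the last axis
sub = L[rows, cols]   # index tensors broadcast to (2, 2); sub[a, b] == L[r[a], r[b]]

# Same result as the chained boolean-mask indexing from the question (i = 0)
mask = torch.tensor([True, True, False, False])
assert torch.equal(sub, L[mask][:, mask])
```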

The same principle can be applied to the second square:

r = torch.tensor([1,2])
L[r[:,None], r]

output:

tensor([[ 5,  6],
        [ 9, 10]])

In combination you get:

i = torch.tensor([[0, 1], [1, 2]])
L[i[:,:,None], i[:,None]]

output:

tensor([[[ 0,  1],
         [ 4,  5]],

        [[ 5,  6],
         [ 9, 10]]])

All 3 squares:

i = torch.tensor([
    [0, 1],
    [1, 2],
    [1, 3],
])
L[i[:,:,None], i[:,None]]

output:

tensor([[[ 0,  1],
         [ 4,  5]],

        [[ 5,  6],
         [ 9, 10]],

        [[ 5,  7],
         [13, 15]]])
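
Since the question also mentions NumPy: the same indexing expression should carry over unchanged to NumPy arrays. This is a sketch under the assumption that L and i are built as above, just with np instead of torch; np.ix_ is shown only as a cross-check for a single index vector.

```python
import numpy as np

N = 4
L = np.arange(N * N).reshape(N, N)
i = np.array([
    [0, 1],
    [1, 2],
    [1, 3],
])

# i[:, :, None] has shape (K, S, 1), i[:, None, :] has shape (K, 1, S);
# together they broadcast to (K, S, S), so out[k, a, b] == L[i[k, a], i[k, b]]
out = L[i[:, :, None], i[:, None, :]]

# np.ix_ builds the same open mesh for one index vector at a time
single = L[np.ix_(i[0], i[0])]
assert np.array_equal(out[0], single)
```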

To summarize, I would suggest using indices instead of a mask.
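
If the selectors already exist as boolean masks, they can be converted to index form once up front. The following is a sketch, not a benchmarked solution, and it assumes (as the question guarantees) that every row of m selects the same number of entries. The names counts, out and expected are introduced here for illustration.

```python
import torch

N, K = 4, 3
L = torch.arange(N * N).reshape(N, N)
m = torch.tensor([[True,  True,  False, False],
                  [False, True,  True,  False],
                  [False, True,  False, True]])

# Assumption from the question: each mask selects the same number of entries
counts = m.sum(dim=1)
assert bool((counts == counts[0]).all()), "each mask must select the same number of entries"

# nonzero() returns (row, column) pairs in row-major order, so the column
# entries can be regrouped into one row of indices per mask
i = m.nonzero()[:, 1].reshape(K, -1)      # shape (K, S)

out = L[i[:, :, None], i[:, None, :]]     # shape (K, S, S)

# Cross-check against the for-loop version from the question
expected = torch.stack([L[m[k]][:, m[k]] for k in range(K)])
assert torch.equal(out, expected)
```

The conversion itself costs one pass over m, so it pays off mainly when the same selectors are reused across many source matrices.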
