How to quickly select a sub matrix in a 2-dimensional matrix using numpy?

Time:04-05

I have a 7×7 matrix and want to quickly slice out a submatrix without using a loop.

matrix= array([[ 0,  1,  2,  3,  4,  5,  6],
   [ 7,  8,  9, 10, 11, 12, 13],
   [14, 15, 16, 17, 18, 19, 20],
   [21, 22, 23, 24, 25, 26, 27],
   [28, 29, 30, 31, 32, 33, 34],
   [35, 36, 37, 38, 39, 40, 41],
   [42, 43, 44, 45, 46, 47, 48]])

sub_matrix = array([[1,2,3], [16,17,18], [28,29,30], [39,40,41]])


In practice the matrix is very large. I have a list of row indices and, for each of those rows, the column where its slice starts. Writing out the full list of column indices for every row by hand is impractical.
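To pin down what "the submatrix" means here, below is a minimal sketch of the plain per-row loop the question wants to avoid. The helper name submatrix_loop is hypothetical. It assumes every row takes the same number of columns (interval) and that each window fits inside the row. It is handy as a reference for checking the vectorized answers.

```python
import numpy as np

# Hypothetical helper: the straightforward per-row loop the question wants to avoid.
# It defines the target result, so vectorized versions can be checked against it.
def submatrix_loop(matrix, rows, col_starts, interval):
    out = np.empty((len(rows), interval), dtype=matrix.dtype)
    for i, (r, c) in enumerate(zip(rows, col_starts)):
        out[i] = matrix[r, c:c + interval]  # one contiguous slice per selected row
    return out

matrix = np.arange(49).reshape(7, 7)
print(submatrix_loop(matrix, [0, 2, 4, 5], [1, 2, 0, 4], 3))
```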

I tried the following method, but it raised this error: IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (4,) (4,3)

slice_row = [0, 2, 4, 5]
slice_col_start = [1,2,0,4]
interval = 3
slice_col = [np.arange(x, x + interval) for x in slice_col_start]

matrix[slice_row, np.r_[slice_col]]
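The error comes from NumPy's rule that integer index arrays used for different axes must broadcast to one common shape. The sketch below reproduces the shape check in isolation with np.broadcast_shapes (available in NumPy 1.20 and later). It shows that (4,) cannot broadcast against (4, 3), while a (4, 1) column of row indices can. That is why the answers below index with y[:, None] or slice_row[:, None].

```python
import numpy as np

# Integer index arrays for different axes must broadcast to a common shape.
row_idx = np.array([0, 2, 4, 5])                          # shape (4,)
col_idx = np.array([1, 2, 0, 4])[:, None] + np.arange(3)  # shape (4, 3)

try:
    # (4,) vs (4, 3): the trailing dimensions 4 and 3 do not match
    np.broadcast_shapes(row_idx.shape, col_idx.shape)
except ValueError as exc:
    print("cannot broadcast:", exc)

# A trailing axis turns the rows into a column: (4, 1) broadcasts with (4, 3).
print(np.broadcast_shapes(row_idx[:, None].shape, col_idx.shape))  # (4, 3)
```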

Answer:

If you already have the column indices for every row, you can index with them directly:

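import numpy as np
matrix = np.arange(49).reshape(7, 7)
x = np.array([[1,2,3], [2,3,4], [0,1,2], [4,5,6]])  # column indices, one row per selected row
y = np.array([0, 2, 4, 5])                           # selected row indices
matrix[y[:,None], x]  # y[:, None] has shape (4, 1) and broadcasts with x's (4, 3)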

output:

array([[ 1,  2,  3],
       [16, 17, 18],
       [28, 29, 30],
       [39, 40, 41]])

Answer:

It can also be done with np.take_along_axis, if the cols array is given:

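import numpy as np
matrix = np.arange(49).reshape(7, 7)
rows = np.array([0, 2, 4, 5], dtype=np.int32)
cols = np.array([[1,2,3], [2,3,4], [0,1,2], [4,5,6]])
# matrix[rows] keeps the 4 selected rows; take_along_axis then picks cols[i] from row i
result = np.take_along_axis(matrix[rows], cols, axis=1)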
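The cols array does not have to be typed out. Because every row takes the same number of columns, it can be built from the start columns by broadcasting. Here is a minimal end-to-end sketch of that, assuming a fixed interval. The names col_starts and interval are taken from the question; the rest follows this answer.

```python
import numpy as np

matrix = np.arange(49).reshape(7, 7)
rows = np.array([0, 2, 4, 5])
col_starts = np.array([1, 2, 0, 4])
interval = 3

# (4, 1) starts + (3,) offsets -> (4, 3) block of column indices
cols = col_starts[:, None] + np.arange(interval)

# matrix[rows] has shape (4, 7); take_along_axis picks cols[i] from row i
result = np.take_along_axis(matrix[rows], cols, axis=1)
print(result)
```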

Answer:

Thanks to Kevin, I came up with a solution:

import numpy as np
matrix = np.arange(7*7).reshape(7,7)
slice_row = np.array([0, 2, 4, 5])
slice_col_start = np.array([1,2,0,4])
interval = 3
slice_col = [np.arange(x, x + interval).tolist() for x in slice_col_start]

sub_matrix = matrix[slice_row[:,None], slice_col]
print(sub_matrix)

output

[[ 1  2  3]
 [16 17 18]
 [28 29 30]
 [39 40 41]]

Answer:

In [11]: arr = np.arange(49).reshape(7,7)
In [12]: slice_row = [0, 2, 4, 5]
    ...: slice_col_start = [1,2,0,4]
    ...: interval = 3
In [13]: idx1 = np.array(slice_row)
In [14]: idx2 = np.array(slice_col_start)

Since the interval is fixed, we can use linspace to create all column indices with one call:

In [19]: idx3 = np.linspace(idx2, idx2+interval, interval, endpoint=False, dtype=int)
In [20]: idx3
Out[20]: 
array([[1, 2, 0, 4],
       [2, 3, 1, 5],
       [3, 4, 2, 6]])

Then it's just a matter of indexing:

In [21]: arr[idx1[:,None], idx3.T]
Out[21]: 
array([[ 1,  2,  3],
       [16, 17, 18],
       [28, 29, 30],
       [39, 40, 41]])
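As a side note, not part of this answer: with a fixed interval, NumPy's sliding_window_view (numpy.lib.stride_tricks, NumPy 1.20 and later) avoids building column indices at all. It exposes every length-interval window of each row as a view. Pairing the row indices with the start columns then picks one window per row. This is a sketch under the assumption that every start column is at most ncols - interval; otherwise the window index is out of range and NumPy raises IndexError.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

arr = np.arange(49).reshape(7, 7)
idx1 = np.array([0, 2, 4, 5])  # selected rows
idx2 = np.array([1, 2, 0, 4])  # start column for each selected row
interval = 3

# windows[r, c] is a read-only view of arr[r, c:c + interval]; shape (7, 5, 3)
windows = sliding_window_view(arr, interval, axis=1)

# idx1 and idx2 have the same shape, so they pick one window per selected row
sub = windows[idx1, idx2]
print(sub)
```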

Or use broadcasted addition:

In [23]: idx2[:,None] + np.arange(3)
Out[23]: 
array([[1, 2, 3],
       [2, 3, 4],
       [0, 1, 2],
       [4, 5, 6]])
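The broadcasted index block can go straight into the indexing expression, with no transpose needed. Below is a minimal sketch of that final step. The bounds check is an addition for illustration, not part of the answer. It fails early with a clearer message if a window would run past the last column.

```python
import numpy as np

arr = np.arange(49).reshape(7, 7)
idx1 = np.array([0, 2, 4, 5])
idx2 = np.array([1, 2, 0, 4])
interval = 3

# (4, 3) block of column indices: start, start + 1, ..., start + interval - 1
cols = idx2[:, None] + np.arange(interval)

# Added guard: every window must fit inside the row
if cols.max() >= arr.shape[1]:
    raise ValueError("a start column + interval runs past the last column")

# (4, 1) row indices broadcast against (4, 3) column indices
sub = arr[idx1[:, None], cols]
print(sub)
```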

If the interval varies by row, we will have to use some form of iteration to get the full list of column indices.
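With per-row lengths the result is ragged, so it cannot be a single 2-D array. A minimal sketch of the iterative approach collects one 1-D slice per row in a list. The intervals values below are hypothetical, chosen only for illustration.

```python
import numpy as np

arr = np.arange(49).reshape(7, 7)
rows = [0, 2, 4, 5]
col_starts = [1, 2, 0, 4]
intervals = [3, 2, 4, 1]  # hypothetical per-row lengths, not from the question

# Rows of different length cannot form a rectangular array,
# so keep one 1-D slice (a view into arr) per selected row.
pieces = [arr[r, c:c + n] for r, c, n in zip(rows, col_starts, intervals)]
for p in pieces:
    print(p)
```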
