Cutting a multidimensional numpy array in half along a selected axis

Time:08-27

I have a multidimensional numpy array like this:

[
  [
    [1,2,3,4,5],
    [6,7,8,9,10],
    [11,12,13,14,15]
  ],
  [
    [16,17,18,19,20],
    [21,22,23,24,25],
    [26,27,28,29,30]
  ]
]
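For experimenting, the same 2x3x5 array can be built with NumPy instead of being typed out by hand. This is a small sketch; the name a simply matches the one used in the answers.

```python
import numpy as np

# The 2x3x5 example array from the question, generated from a range
# (values 1..30) instead of being written out literally.
a = np.arange(1, 31).reshape(2, 3, 5)

print(a.shape)  # (2, 3, 5)
```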

and would like to write a function that cuts it in half along a specified axis and returns the first half, leaving out the middle element when the length of that axis is odd. So if I call my_function(my_ndarray, 0), I want to get

[
  [
    [1,2,3,4,5],
    [6,7,8,9,10],
    [11,12,13,14,15]
  ]
]

For my_function(my_ndarray, 1) I want to get

[
  [
    [1,2,3,4,5]
  ],
  [
    [16,17,18,19,20]
  ]
]

and for my_function(my_ndarray, 2) I want to get

[
  [
    [1,2],
    [6,7],
    [11,12]
  ],
  [
    [16,17],
    [21,22],
    [26,27]
  ]
]
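Stated in terms of shapes, every requested result keeps the first n // 2 entries along the chosen axis, where n is that axis's length, so an odd length loses its middle entry. A quick sketch that computes the expected shapes for the example array:

```python
import numpy as np

a = np.arange(1, 31).reshape(2, 3, 5)

# The chosen axis shrinks to n // 2 (3 -> 1, 5 -> 2, 2 -> 1);
# every other axis keeps its full length.
expected = {}
for axis in range(a.ndim):
    shape = list(a.shape)
    shape[axis] = shape[axis] // 2
    expected[axis] = tuple(shape)

print(expected)  # {0: (1, 3, 5), 1: (2, 1, 5), 2: (2, 3, 2)}
```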

My first attempt used np.split(), but it raises an error when the length of the axis is odd (the array cannot be divided into two equal parts), and it doesn't let me specify that I want to omit the middle element. In theory I could use an if statement to cut away the last slice of the selected axis in that case, but I would like to know if there is a more efficient way to solve this problem.
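For what it's worth, np.split can still handle odd lengths: besides a section count, its indices_or_sections argument accepts a sorted list of split points. The minimal sketch below (first_half_via_split is a hypothetical helper name) splits at [n // 2, n - n // 2], which yields three pieces: the first half, a middle piece that is empty for even n and holds the single middle entry for odd n, and the second half.

```python
import numpy as np

def first_half_via_split(a, axis):
    # Hypothetical helper. Splitting at [n // 2, n - n // 2] gives
    # (first half, middle, second half); the middle piece is empty when
    # n is even, so no ValueError is raised for odd lengths.
    n = a.shape[axis]
    first, middle, second = np.split(a, [n // 2, n - n // 2], axis=axis)
    return first

a = np.arange(1, 31).reshape(2, 3, 5)
print(first_half_via_split(a, 2)[0])
```

Compared with slicing directly, this builds the unused middle and second pieces as well, so it is mainly useful when those pieces are wanted too.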

Answer:

Given an array a and an axis axis, I think you can do

def my_function(a, axis):
    # Keep the first n // 2 entries along `axis`; for odd n this drops the middle one
    half = a.shape[axis] // 2
    return a.take(range(half), axis=axis)

Examples:

>>> my_function(a, 0)
array([[[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]]])

>>> my_function(a, 1)
array([[[ 1,  2,  3,  4,  5]],

       [[16, 17, 18, 19, 20]]])

>>> my_function(a, 2)
array([[[ 1,  2],
        [ 6,  7],
        [11, 12]],

       [[16, 17],
        [21, 22],
        [26, 27]]])
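One practical difference worth knowing: take uses integer indexing and always returns a new array (a copy), whereas basic slicing such as a[:, :, :half] returns a view that shares memory with a. For large arrays the view avoids copying, but writes to it also change a. A small check on the 2x3x5 example array:

```python
import numpy as np

a = np.arange(1, 31).reshape(2, 3, 5)
half = a.shape[2] // 2

taken = a.take(range(half), axis=2)  # integer indexing: a new array
sliced = a[:, :, :half]              # basic slicing: a view into a

print(np.array_equal(taken, sliced))  # True: same values
print(np.shares_memory(a, taken))     # False: separate buffer
print(np.shares_memory(a, sliced))    # True: same buffer as a
```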

Answer:

What about:

def slice_n(a, n):
    # Full slice on every axis except axis n, which keeps only its first half
    slices = [slice(None)] * a.ndim
    slices[n] = slice(0, a.shape[n] // 2)
    return a[tuple(slices)]

slice_n(a, 0)
array([[[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]]])

slice_n(a, 1)
array([[[ 1,  2,  3,  4,  5]],

       [[16, 17, 18, 19, 20]]])

slice_n(a, 2)
array([[[ 1,  2],
        [ 6,  7],
        [11, 12]],

       [[16, 17],
        [21, 22],
        [26, 27]]])

Input used (a):

array([[[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]],

       [[16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25],
        [26, 27, 28, 29, 30]]])
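If the second half is needed as well, the same slice-tuple idea extends naturally. A sketch under that assumption; split_halves is a hypothetical name, not a NumPy function, and both results are views into a:

```python
import numpy as np

def split_halves(a, axis):
    # Hypothetical generalisation of slice_n: returns (first, second)
    # along `axis`, skipping the middle entry when the length is odd.
    n = a.shape[axis]
    half = n // 2
    first = [slice(None)] * a.ndim
    second = [slice(None)] * a.ndim
    first[axis] = slice(0, half)
    second[axis] = slice(n - half, n)
    return a[tuple(first)], a[tuple(second)]

a = np.arange(1, 31).reshape(2, 3, 5)
first, second = split_halves(a, 1)
print(second)
```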