Delete all zeros slices from 4d numpy array

Time:09-10

I intend to remove slices from the third dimension of a 4D numpy array if they contain only zeros.

I have a 4D numpy array with shape [256, 256, 336, 6], and I need to delete the slices along the third dimension that contain only zeros. For example, if 36 slices are entirely zero, the result would have shape [256, 256, 300, 6]. I have tried multiple approaches without success, including for loops, np.delete, and the all() and any() functions.
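For reference, np.delete and np.all can be combined for this if the "all zeros" test reduces over every axis except the third one (axis 2). Below is a minimal sketch on a small stand-in array. The shape (4, 4, 10, 2) and the positions of the zero slices are assumptions chosen for the demo, not taken from the question:

```python
import numpy as np

# Small stand-in for the [256, 256, 336, 6] array (shape assumed for the demo).
a = np.ones((4, 4, 10, 2))
a[:, :, [2, 5, 7], :] = 0  # make three slices along axis 2 entirely zero

# A slice along axis 2 is all zeros when every element over axes 0, 1 and 3 is 0.
zero_slices = np.all(a == 0, axis=(0, 1, 3))  # boolean array of length a.shape[2]

# np.delete takes the indices to remove and the axis to remove them from.
a_new = np.delete(a, np.flatnonzero(zero_slices), axis=2)
print(a_new.shape)  # (4, 4, 7, 2)
```

The boolean result is converted to integer indices with np.flatnonzero so that np.delete receives plain positions.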

Answer:

I'm not a numpy aficionado, but does this do what you want?

I take the following small example array, which has 4 dimensions and is filled with 1s, and then set some slices to zero:

import numpy as np
a=np.ones((4,4,5,2))

The shape of a is:

>>> a.shape
(4, 4, 5, 2)

I artificially set some of the slices along the third dimension (axis 2) to zero:

a[:,:,0,:]=0
a[:,:,3,:]=0

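I can find the indices of the slices that are not all zeros by summing each slice (not very efficient, perhaps!). This assumes the array has no negative values. Otherwise, positive and negative entries in a slice could cancel out to a sum of zero.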

indices = [i for i in range(a.shape[2]) if a[:,:,i,:].sum() != 0]
>>> indices
[1, 2, 4]
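To illustrate the caveat about negative values, here is a small sketch. An extra slice is made deliberately non-zero but with a zero sum (a hypothetical case added for the demo). It compares the sum-based test with an np.any-based test:

```python
import numpy as np

a = np.ones((4, 4, 5, 2))
a[:, :, 0, :] = 0
a[:, :, 3, :] = 0
# Hypothetical slice that is NOT all zeros but whose entries sum to zero.
a[:, :, 1, :] = 0
a[0, 0, 1, 0] = 1.0
a[0, 0, 1, 1] = -1.0

by_sum = [i for i in range(a.shape[2]) if a[:, :, i, :].sum() != 0]
by_any = [i for i in range(a.shape[2]) if np.any(a[:, :, i, :])]
print(by_sum)  # [2, 4]     -> slice 1 is wrongly treated as all zeros
print(by_any)  # [1, 2, 4]  -> slice 1 is kept
```

np.any checks for any non-zero element directly, so it does not depend on the sign of the values.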

So, in your general case you could do this:

indices = [i for i in range(a.shape[2]) if a[:,:,i,:].sum() != 0]
a_new = a[:, :, indices, :].copy()

Then the shape of a_new is:

>>> a_new.shape
(4, 4, 3, 2)
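A side note on the `.copy()` call: indexing with a list of integers is NumPy "advanced" indexing, which returns a new array rather than a view. As far as I know, the extra copy is therefore harmless but unnecessary. A quick check with np.shares_memory:

```python
import numpy as np

a = np.ones((4, 4, 5, 2))
a[:, :, 0, :] = 0
a[:, :, 3, :] = 0
indices = [i for i in range(a.shape[2]) if a[:, :, i, :].sum() != 0]

# Advanced (integer-list) indexing already produces an independent array.
a_new = a[:, :, indices, :]
print(np.shares_memory(a, a_new))  # False

a_new[0, 0, 0, 0] = 42.0  # modify the result ...
print(a[0, 0, 1, 0])      # ... the original is untouched: 1.0
```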

Answer:

You need to reduce over all axes except the one you are interested in.

Here is an example using np.any(), where there are all-zero subarrays along axis 1 (at positions 0 and 2):

import numpy as np


a=np.ones((2, 3, 2, 3))
a[:, 0, :, :] = a[:, 2, :, :] =0


mask = np.any(a, axis=(0, 2, 3))
new_a = a[:, mask, :, :]
print(new_a.shape)
# (2, 1, 2, 3)
print(new_a)
# [[[[1. 1. 1.]
#    [1. 1. 1.]]]
#
#
#  [[[1. 1. 1.]
#    [1. 1. 1.]]]]
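For the layout in the question, the axis of interest is axis 2, so the reduction runs over axes 0, 1 and 3. Here is a minimal sketch on a small stand-in array. The shape (4, 4, 10, 2) and the zeroed positions are assumptions for the demo:

```python
import numpy as np

# Small stand-in for the question's [256, 256, 336, 6] array (shape assumed).
a = np.ones((4, 4, 10, 2))
a[:, :, [0, 3, 9], :] = 0

# Reduce over every axis except axis 2, the one whose slices we filter.
mask = np.any(a, axis=(0, 1, 3))  # True where the slice has at least one non-zero
new_a = a[:, :, mask, :]
print(new_a.shape)  # (4, 4, 7, 2)
```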

Here is the same code, parametrized and refactored as a function:

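def remove_all_zeros(arr: np.ndarray, axis: int) -> np.ndarray:
    axis = axis % arr.ndim  # accept negative axes such as -2
    # Reduce over every axis except `axis`: True where the slice has a non-zero.
    red_axes = tuple(i for i in range(arr.ndim) if i != axis)
    mask = np.any(arr, axis=red_axes)
    # Full slices on all other axes, the boolean mask on `axis`.
    slicing = tuple(slice(None) if i != axis else mask for i in range(arr.ndim))
    return arr[slicing]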


a = np.ones((2, 3, 2, 3))
a[:, 0, :, :] = a[:, 2, :, :] = 0
new_a = remove_all_zeros(a, 1)
print(new_a.shape)
# (2, 1, 2, 3)
print(new_a)
# [[[[1. 1. 1.]
#    [1. 1. 1.]]]
#
#
#  [[[1. 1. 1.]
#    [1. 1. 1.]]]]
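A possible variant of the same idea uses np.compress, which keeps the entries along a given axis where a boolean condition is True. This avoids building the slicing tuple by hand. The helper name `remove_all_zeros_compress` is hypothetical. The sketch also normalizes negative axis values:

```python
import numpy as np


def remove_all_zeros_compress(arr: np.ndarray, axis: int) -> np.ndarray:
    """Drop all-zero slices along `axis` (hypothetical helper, np.compress variant)."""
    axis = axis % arr.ndim  # normalize negative axes, e.g. -3 -> 1 for a 4-D array
    red_axes = tuple(i for i in range(arr.ndim) if i != axis)
    mask = np.any(arr, axis=red_axes)  # one boolean per slice along `axis`
    # np.compress keeps the positions along `axis` where `mask` is True.
    return np.compress(mask, arr, axis=axis)


a = np.ones((2, 3, 2, 3))
a[:, 0, :, :] = a[:, 2, :, :] = 0
print(remove_all_zeros_compress(a, 1).shape)   # (2, 1, 2, 3)
print(remove_all_zeros_compress(a, -3).shape)  # (2, 1, 2, 3)
```

For the question's array, the call would be `remove_all_zeros_compress(arr, 2)`.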