I want to remove slices from the third dimension of a 4D NumPy array if they contain only zeros.
I have a 4D NumPy array of shape [256, 256, 336, 6], and I need to delete the slices along the third dimension that contain only zeros. The result would then have a shape like [256, 256, 300, 6]
if 36 slices are entirely zero. I have tried multiple approaches, including for loops, np.delete,
and the all() and any() functions, without success.
Answer:
I'm not a NumPy aficionado, but does this do what you want?
I take the following small example array with 4 dimensions, all filled with 1s, and then set some slices to zero:
import numpy as np
a=np.ones((4,4,5,2))
The shape of a is:
>>> a.shape
(4, 4, 5, 2)
I will artificially set some of the slices in dimension 3 to zero:
a[:,:,0,:]=0
a[:,:,3,:]=0
I can find the indices of the slices that are not all zeros by computing sums (perhaps not very efficient!):
indices = [i for i in range(a.shape[2]) if a[:,:,i,:].sum() != 0]
>>> indices
[1, 2, 4]
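One caveat with the sum test: if the array can hold negative values, a slice that is not all zeros can still sum to exactly 0 and would be dropped by mistake. The small sketch below uses a made-up array to contrast `.sum() != 0` with `.any()`, which instead checks whether the slice has at least one nonzero element:

```python
import numpy as np

# Made-up example: slice 1 holds +1 and -1, so it is NOT all zeros,
# but its plain sum is exactly 0.
a = np.zeros((2, 2, 3, 1))
a[0, 0, 1, 0] = 1.0
a[1, 1, 1, 0] = -1.0
a[:, :, 2, :] = 5.0

# Sum-based test: misses slice 1 because +1 and -1 cancel out
plain_sum = [i for i in range(a.shape[2]) if a[:, :, i, :].sum() != 0]
# any()-based test: keeps every slice with at least one nonzero value
any_nonzero = [i for i in range(a.shape[2]) if a[:, :, i, :].any()]

print(plain_sum)    # [2]
print(any_nonzero)  # [1, 2]
```

If the data is known to be non-negative (as in the all-ones example), both tests agree.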
So, in your general case you could do this:
indices = [i for i in range(a.shape[2]) if a[:,:,i,:].sum() != 0]
a_new = a[:, :, indices, :].copy()
Then the shape of a_new is:
>>> a_new.shape
(4, 4, 3, 2)
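The Python loop can also be replaced by a single vectorized reduction. Since the question mentions trying np.delete, here is a sketch (rebuilding the same toy array as above) that computes the mask with np.any over every axis except axis 2, then applies it either as a boolean index or, via np.flatnonzero, as a list of indices passed to np.delete. Both should produce the same array:

```python
import numpy as np

# Same toy array as above: slices 0 and 3 along axis 2 are all zeros
a = np.ones((4, 4, 5, 2))
a[:, :, 0, :] = 0
a[:, :, 3, :] = 0

# One boolean per slice along axis 2: True if the slice has any nonzero value
keep = np.any(a, axis=(0, 1, 3))

# Option 1: boolean indexing keeps only the non-zero slices
kept = a[:, :, keep, :]

# Option 2: np.delete removes the all-zero slices by their integer indices
dropped = np.delete(a, np.flatnonzero(~keep), axis=2)

print(kept.shape, dropped.shape)      # (4, 4, 3, 2) (4, 4, 3, 2)
print(np.array_equal(kept, dropped))  # True
```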
Answer:
You need to reduce over all axes except the one you are interested in.
Here is an example using np.any(), where there are all-zero subarrays along axis 1 (at positions 0 and 2):
import numpy as np
a=np.ones((2, 3, 2, 3))
a[:, 0, :, :] = a[:, 2, :, :] = 0
mask = np.any(a, axis=(0, 2, 3))
new_a = a[:, mask, :, :]
print(new_a.shape)
# (2, 1, 2, 3)
print(new_a)
# [[[[1. 1. 1.]
#    [1. 1. 1.]]]
#
#
#  [[[1. 1. 1.]
#    [1. 1. 1.]]]]
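To see what the reduction actually produces, this sketch prints the intermediate mask for the same example: reducing over axes 0, 2 and 3 leaves exactly one boolean per index along axis 1. Since the question mentions both any() and all(), it also checks that the same mask can be phrased with all() (a slice is dropped only if every element equals 0):

```python
import numpy as np

a = np.ones((2, 3, 2, 3))
a[:, 0, :, :] = a[:, 2, :, :] = 0

# Reducing over every axis except axis 1 leaves one value per index on axis 1
mask = np.any(a, axis=(0, 2, 3))
print(mask.shape)  # (3,)
print(mask)        # [False  True False]

# Equivalent formulation with all(): keep a slice unless all of it equals 0
same_mask = ~np.all(a == 0, axis=(0, 2, 3))
print(np.array_equal(mask, same_mask))  # True
```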
The same code parametrized and refactored as a function:
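def remove_all_zeros(arr: np.ndarray, axis: int) -> np.ndarray:
    # Reduce over every axis except `axis`, leaving one boolean per slice
    red_axes = tuple(i for i in range(arr.ndim) if i != axis)
    mask = np.any(arr, axis=red_axes)
    # Full slices on the other axes, the boolean mask on `axis`
    slicing = tuple(slice(None) if i != axis else mask for i in range(arr.ndim))
    return arr[slicing]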
a = np.ones((2, 3, 2, 3))
a[:, 0, :, :] = a[:, 2, :, :] = 0
new_a = remove_all_zeros(a, 1)
print(new_a.shape)
# (2, 1, 2, 3)
print(new_a)
# [[[[1. 1. 1.]
#    [1. 1. 1.]]]
#
#
#  [[[1. 1. 1.]
#    [1. 1. 1.]]]]
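For the question's layout, the slices live on the third dimension, i.e. axis=2. The sketch below is a scaled-down stand-in for the [256, 256, 336, 6] array: the (8, 8) leading axes, the random fill and the randomly chosen zero slices are assumptions made for illustration and to keep memory low. With 36 all-zero slices out of 336, the result should keep 300. It also compares against np.compress, which selects along a single axis with a boolean condition and so avoids building the slicing tuple:

```python
import numpy as np

def remove_all_zeros(arr: np.ndarray, axis: int) -> np.ndarray:
    # Reduce over every axis except `axis`, leaving one boolean per slice
    red_axes = tuple(i for i in range(arr.ndim) if i != axis)
    mask = np.any(arr, axis=red_axes)
    # Full slices on the other axes, the boolean mask on `axis`
    slicing = tuple(slice(None) if i != axis else mask for i in range(arr.ndim))
    return arr[slicing]

# Hypothetical test data: values in [1, 2), so no slice is zero by accident
rng = np.random.default_rng(0)
a = rng.random((8, 8, 336, 6)) + 1.0
zero_slices = rng.choice(336, size=36, replace=False)  # 36 distinct indices
a[:, :, zero_slices, :] = 0

new_a = remove_all_zeros(a, axis=2)
print(new_a.shape)  # (8, 8, 300, 6)

# np.compress selects along one axis directly and gives the same result
alt = np.compress(np.any(a, axis=(0, 1, 3)), a, axis=2)
print(np.array_equal(new_a, alt))  # True
```

Because np.any treats any nonzero value (including negative ones) as True, this approach does not suffer from the cancellation issue of a sum-based test.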