I have an array built with `np.dstack`, like this:
```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))  # shape (1, 3, 3)
print(stack)
# [[[1 2 8]
#   [2 3 3]
#   [6 4 0]]]
```
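It may help to check which axis holds what. According to the NumPy documentation, `np.dstack` reshapes each 1-D array of length N to (1, N, 1) and stacks them along the third axis, so `stack` has shape (1, 3, 3) and each input array becomes a *column* of the single slice `stack[0]`. A minimal check:

```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))

print(stack.shape)                        # (1, 3, 3)
# Each input array ends up as a column of the single 3x3 slice,
# so column 2 of stack[0] is the original array c.
print(stack[0, :, 2])                     # [8 3 0]
print(np.array_equal(stack[0, :, 2], c))  # True
```

So "element 2 of each row" lives on the last axis, `stack[:, :, 2]`.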
I want to filter out the rows whose element at index 2 (the third value) is less than 1.
Something like this:
```python
new_list = []
for i in stack:
    for d in i[:, 2]:
        if d >= 1:
            new_list.append(d)
print(new_list)  # [8, 3]
```
This only collects the element at index 2, but I want the whole row, like this:
```
# [[[1 2 8]
#   [2 3 3]]]
```
And if I use `new_list.append(i)` instead, the result isn't what I want either.
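For comparison, the loop does work if it iterates over the rows of the slice and appends each passing row whole. `append(i)` goes wrong because `i` is the entire 3x3 slice, so it gets appended once for every passing element. This sketch assumes, as in the example, that `stack` has a single 2-D slice (`stack[0]`); `rows` and `result` are illustrative names.

```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))  # shape (1, 3, 3)

rows = []
for row in stack[0]:          # iterate over the rows of the single 3x3 slice
    if row[2] >= 1:           # keep the row only if its element at index 2 is >= 1
        rows.append(row)      # append the whole row, not just row[2]

# Extra brackets give shape (1, len(rows), 3) when at least one row is kept.
result = np.array([rows])
print(result)
# [[[1 2 8]
#   [2 3 3]]]
```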
Answer:
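You don't need a loop; boolean-mask indexing does it. Build the mask from the last axis, `stack[:, :, 2]` (element 2 of every row). `stack[:, 2]` would instead test the third *row*, `[6 4 0]`, and only gives the same mask here by coincidence.
```python
print(stack[stack[:, :, 2] >= 1])
```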
Output:
```
[[1 2 8]
 [2 3 3]]
```
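To illustrate the difference between the two masks: `stack[:, 2]` is shorthand for `stack[:, 2, :]`, which tests the third row of each 2-D slice, while `stack[:, :, 2]` tests element 2 of every row. The sketch below uses hypothetical data (not from the question), with `c` changed to `(8, 0, 5)` so that row 2 and column 2 of `stack[0]` differ.

```python
import numpy as np

# Hypothetical data: c is changed so that row 2 and column 2 of stack[0] differ.
a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 0, 5))
stack = np.dstack((a, b, c))  # stack[0] is [[1 2 8], [2 3 0], [6 4 5]]

row_mask = stack[:, 2] >= 1     # same as stack[:, 2, :]: tests row [6 4 5] -> all True
col_mask = stack[:, :, 2] >= 1  # tests element 2 of every row: [8 0 5] -> True, False, True

print(stack[row_mask])  # keeps all three rows, including [2 3 0]
print(stack[col_mask])  # keeps only [1 2 8] and [6 4 5]
```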
If you need it with shape (2, 1, 3), i.e.
```
[[[1 2 8]]

 [[2 3 3]]]
```
you can reshape the result:
```python
filtered = stack[stack[:, :, 2] >= 1]  # shape (2, 3)
shape = filtered.shape
print(filtered.reshape((shape[0], 1, shape[1])))
```
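If you want exactly the layout from the question, shape (1, 2, 3) with the kept rows still inside one 2-D slice, one option is to apply a 1-D boolean mask to the middle axis only. This is a sketch that assumes `stack` has a single slice along axis 0 (which `np.dstack` of 1-D arrays produces); `keep`, `result` and `filtered` are illustrative names. The last lines show `np.newaxis` as an alternative to `reshape` for the (2, 1, 3) layout.

```python
import numpy as np

a = np.array((1, 2, 6))
b = np.array((2, 3, 4))
c = np.array((8, 3, 0))
stack = np.dstack((a, b, c))  # shape (1, 3, 3)

keep = stack[0, :, 2] >= 1    # 1-D mask over the rows of the single slice
result = stack[:, keep, :]    # mask the middle axis, leave axis 0 intact
print(result.shape)           # (1, 2, 3)
print(result)
# [[[1 2 8]
#   [2 3 3]]]

# The (2, 1, 3) layout via np.newaxis instead of reshape:
filtered = stack[stack[:, :, 2] >= 1]    # shape (2, 3)
print(filtered[:, np.newaxis, :].shape)  # (2, 1, 3)
```

With several slices along axis 0 each slice could keep a different number of rows, so a single rectangular (n, k, 3) result would not exist in general; that is why the sketch relies on there being one slice.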